Skip to main content

datafusion_spark/function/math/
width_bucket.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, ArrayRef, DurationMicrosecondArray, Float64Array, IntervalMonthDayNanoArray,
23    IntervalYearMonthArray,
24};
25use arrow::datatypes::DataType;
26use arrow::datatypes::DataType::{Duration, Float64, Int32, Interval};
27use arrow::datatypes::IntervalUnit::{MonthDayNano, YearMonth};
28use datafusion_common::cast::{
29    as_duration_microsecond_array, as_float64_array, as_int64_array,
30    as_interval_mdn_array, as_interval_ym_array,
31};
32use datafusion_common::types::{
33    NativeType, logical_duration_microsecond, logical_float64, logical_int64,
34    logical_interval_mdn, logical_interval_year_month,
35};
36use datafusion_common::{Result, exec_err, internal_err};
37use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
38use datafusion_expr::{
39    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
40    TypeSignatureClass,
41};
42use datafusion_functions::utils::make_scalar_function;
43
44use arrow::array::{Int32Array, Int32Builder, Int64Array};
45use arrow::datatypes::TimeUnit::Microsecond;
46use datafusion_expr::Coercion;
47use datafusion_expr::Volatility::Immutable;
48
49#[derive(Debug, PartialEq, Eq, Hash)]
50pub struct SparkWidthBucket {
51    signature: Signature,
52}
53
54impl Default for SparkWidthBucket {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl SparkWidthBucket {
61    pub fn new() -> Self {
62        let numeric = Coercion::new_implicit(
63            TypeSignatureClass::Native(logical_float64()),
64            vec![TypeSignatureClass::Numeric],
65            NativeType::Float64,
66        );
67        let duration = Coercion::new_implicit(
68            TypeSignatureClass::Native(logical_duration_microsecond()),
69            vec![TypeSignatureClass::Duration],
70            NativeType::Duration(Microsecond),
71        );
72        let interval_ym = Coercion::new_exact(TypeSignatureClass::Native(
73            logical_interval_year_month(),
74        ));
75        let interval_mdn =
76            Coercion::new_exact(TypeSignatureClass::Native(logical_interval_mdn()));
77        let bucket = Coercion::new_implicit(
78            TypeSignatureClass::Native(logical_int64()),
79            vec![TypeSignatureClass::Integer],
80            NativeType::Int64,
81        );
82        let type_signature = Signature::one_of(
83            vec![
84                TypeSignature::Coercible(vec![
85                    numeric.clone(),
86                    numeric.clone(),
87                    numeric.clone(),
88                    bucket.clone(),
89                ]),
90                TypeSignature::Coercible(vec![
91                    duration.clone(),
92                    duration.clone(),
93                    duration.clone(),
94                    bucket.clone(),
95                ]),
96                TypeSignature::Coercible(vec![
97                    interval_ym.clone(),
98                    interval_ym.clone(),
99                    interval_ym.clone(),
100                    bucket.clone(),
101                ]),
102                TypeSignature::Coercible(vec![
103                    interval_mdn.clone(),
104                    interval_mdn.clone(),
105                    interval_mdn.clone(),
106                    bucket.clone(),
107                ]),
108            ],
109            Immutable,
110        )
111        .with_parameter_names(vec!["expr", "min", "max", "num_buckets"])
112        .expect("valid parameter names");
113        Self {
114            signature: type_signature,
115        }
116    }
117}
118
119impl ScalarUDFImpl for SparkWidthBucket {
120    fn as_any(&self) -> &dyn Any {
121        self
122    }
123
124    fn name(&self) -> &str {
125        "width_bucket"
126    }
127
128    fn signature(&self) -> &Signature {
129        &self.signature
130    }
131
132    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
133        Ok(Int32)
134    }
135
136    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
137        make_scalar_function(width_bucket_kern, vec![])(&args.args)
138    }
139
140    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
141        if input.len() == 1 {
142            let value = &input[0];
143            Ok(value.sort_properties)
144        } else {
145            Ok(SortProperties::default())
146        }
147    }
148}
149
150fn width_bucket_kern(args: &[ArrayRef]) -> Result<ArrayRef> {
151    let [v, minv, maxv, nb] = args else {
152        return exec_err!(
153            "width_bucket expects exactly 4 argument, got {}",
154            args.len()
155        );
156    };
157
158    match v.data_type() {
159        Float64 => {
160            let v = as_float64_array(v)?;
161            let min = as_float64_array(minv)?;
162            let max = as_float64_array(maxv)?;
163            let n_bucket = as_int64_array(nb)?;
164            Ok(Arc::new(width_bucket_float64(v, min, max, n_bucket)))
165        }
166        Duration(Microsecond) => {
167            let v = as_duration_microsecond_array(v)?;
168            let min = as_duration_microsecond_array(minv)?;
169            let max = as_duration_microsecond_array(maxv)?;
170            let n_bucket = as_int64_array(nb)?;
171            Ok(Arc::new(width_bucket_i64_as_float(v, min, max, n_bucket)))
172        }
173        Interval(YearMonth) => {
174            let v = as_interval_ym_array(v)?;
175            let min = as_interval_ym_array(minv)?;
176            let max = as_interval_ym_array(maxv)?;
177            let n_bucket = as_int64_array(nb)?;
178            Ok(Arc::new(width_bucket_i32_as_float(v, min, max, n_bucket)))
179        }
180        Interval(MonthDayNano) => {
181            let v = as_interval_mdn_array(v)?;
182            let min = as_interval_mdn_array(minv)?;
183            let max = as_interval_mdn_array(maxv)?;
184            let n_bucket = as_int64_array(nb)?;
185            Ok(Arc::new(width_bucket_interval_mdn_exact(
186                v, min, max, n_bucket,
187            )))
188        }
189
190        other => internal_err!(
191            "width_bucket received unexpected data types: {:?}, {:?}, {:?}, {:?}",
192            other,
193            minv.data_type(),
194            maxv.data_type(),
195            nb.data_type()
196        ),
197    }
198}
199
200macro_rules! width_bucket_kernel_impl {
201    ($name:ident, $arr_ty:ty, $to_f64:expr, $check_nan:expr) => {
202        pub(crate) fn $name(
203            v: &$arr_ty,
204            min: &$arr_ty,
205            max: &$arr_ty,
206            n_bucket: &Int64Array,
207        ) -> Int32Array {
208            let len = v.len();
209            let mut b = Int32Builder::with_capacity(len);
210
211            for i in 0..len {
212                if v.is_null(i) || min.is_null(i) || max.is_null(i) || n_bucket.is_null(i)
213                {
214                    b.append_null();
215                    continue;
216                }
217                let x = ($to_f64)(v, i);
218                let l = ($to_f64)(min, i);
219                let h = ($to_f64)(max, i);
220                let buckets = n_bucket.value(i);
221
222                if buckets <= 0 {
223                    b.append_null();
224                    continue;
225                }
226                let next_bucket = (buckets + 1) as i32;
227                if $check_nan {
228                    if !x.is_finite() || !l.is_finite() || !h.is_finite() {
229                        b.append_null();
230                        continue;
231                    }
232                }
233
234                let ord = match l.partial_cmp(&h) {
235                    Some(o) => o,
236                    None => {
237                        b.append_null();
238                        continue;
239                    }
240                };
241                if ord == std::cmp::Ordering::Equal {
242                    b.append_null();
243                    continue;
244                }
245                let asc = ord == std::cmp::Ordering::Less;
246
247                if asc {
248                    if x < l {
249                        b.append_value(0);
250                        continue;
251                    }
252                    if x >= h {
253                        b.append_value(next_bucket);
254                        continue;
255                    }
256                } else {
257                    if x > l {
258                        b.append_value(0);
259                        continue;
260                    }
261                    if x <= h {
262                        b.append_value(next_bucket);
263                        continue;
264                    }
265                }
266
267                let width = (h - l) / (buckets as f64);
268                if width == 0.0 || !width.is_finite() {
269                    b.append_null();
270                    continue;
271                }
272                let mut bucket = ((x - l) / width).floor() as i32 + 1;
273                if bucket < 1 {
274                    bucket = 1;
275                }
276                if bucket > next_bucket {
277                    bucket = next_bucket;
278                }
279
280                b.append_value(bucket);
281            }
282
283            b.finish()
284        }
285    };
286}
287
288width_bucket_kernel_impl!(
289    width_bucket_float64,
290    Float64Array,
291    |arr: &Float64Array, i: usize| arr.value(i),
292    true
293);
294
295width_bucket_kernel_impl!(
296    width_bucket_i64_as_float,
297    DurationMicrosecondArray,
298    |arr: &DurationMicrosecondArray, i: usize| arr.value(i) as f64,
299    false
300);
301
302width_bucket_kernel_impl!(
303    width_bucket_i32_as_float,
304    IntervalYearMonthArray,
305    |arr: &IntervalYearMonthArray, i: usize| arr.value(i) as f64,
306    false
307);
308const NS_PER_DAY_I128: i128 = 86_400_000_000_000;
309pub(crate) fn width_bucket_interval_mdn_exact(
310    v: &IntervalMonthDayNanoArray,
311    lo: &IntervalMonthDayNanoArray,
312    hi: &IntervalMonthDayNanoArray,
313    n: &Int64Array,
314) -> Int32Array {
315    let len = v.len();
316    let mut b = Int32Builder::with_capacity(len);
317
318    for i in 0..len {
319        if v.is_null(i) || lo.is_null(i) || hi.is_null(i) || n.is_null(i) {
320            b.append_null();
321            continue;
322        }
323        let buckets = n.value(i);
324        if buckets <= 0 {
325            b.append_null();
326            continue;
327        }
328        let next_bucket = (buckets + 1) as i32;
329
330        let x = v.value(i);
331        let l = lo.value(i);
332        let h = hi.value(i);
333
334        // asc/desc
335        // Values of IntervalMonthDayNano are compared using their binary representation, which can lead to surprising results.
336        let asc = (l.months, l.days, l.nanoseconds) < (h.months, h.days, h.nanoseconds);
337        if (l.months, l.days, l.nanoseconds) == (h.months, h.days, h.nanoseconds) {
338            b.append_null();
339            continue;
340        }
341
342        // ------------------- only month -------------------
343        if l.days == h.days && l.nanoseconds == h.nanoseconds && l.months != h.months {
344            let x_m = x.months as f64;
345            let l_m = l.months as f64;
346            let h_m = h.months as f64;
347
348            if asc {
349                if x_m < l_m {
350                    b.append_value(0);
351                    continue;
352                }
353                if x_m >= h_m {
354                    b.append_value(next_bucket);
355                    continue;
356                }
357            } else {
358                if x_m > l_m {
359                    b.append_value(0);
360                    continue;
361                }
362                if x_m <= h_m {
363                    b.append_value(next_bucket);
364                    continue;
365                }
366            }
367
368            let width = (h_m - l_m) / (buckets as f64);
369            if width == 0.0 || !width.is_finite() {
370                b.append_null();
371                continue;
372            }
373
374            let mut bucket = ((x_m - l_m) / width).floor() as i32 + 1;
375            if bucket < 1 {
376                bucket = 1;
377            }
378            if bucket > next_bucket {
379                bucket = next_bucket;
380            }
381            b.append_value(bucket);
382            continue;
383        }
384
385        // ---------------  months equals -------------------
386        if l.months == h.months {
387            let base_days = l.days as i128;
388            let base_ns = l.nanoseconds as i128;
389
390            let xf = (x.days as i128 - base_days) * NS_PER_DAY_I128
391                + (x.nanoseconds as i128 - base_ns);
392            let hf = (h.days as i128 - base_days) * NS_PER_DAY_I128
393                + (h.nanoseconds as i128 - base_ns);
394
395            let x_f = xf as f64;
396            let l_f = 0.0;
397            let h_f = hf as f64;
398
399            if asc {
400                if x_f < l_f {
401                    b.append_value(0);
402                    continue;
403                }
404                if x_f >= h_f {
405                    b.append_value(next_bucket);
406                    continue;
407                }
408            } else {
409                if x_f > l_f {
410                    b.append_value(0);
411                    continue;
412                }
413                if x_f <= h_f {
414                    b.append_value(next_bucket);
415                    continue;
416                }
417            }
418
419            let width = (h_f - l_f) / (buckets as f64);
420            if width == 0.0 || !width.is_finite() {
421                b.append_null();
422                continue;
423            }
424
425            let mut bucket = ((x_f - l_f) / width).floor() as i32 + 1;
426            if bucket < 1 {
427                bucket = 1;
428            }
429            if bucket > next_bucket {
430                bucket = next_bucket;
431            }
432            b.append_value(bucket);
433            continue;
434        }
435
436        b.append_null();
437    }
438
439    b.finish()
440}
441
// Unit tests that call `width_bucket_kern` directly on Arrow arrays, covering
// each supported operand type, ascending and descending ranges, NULL
// propagation, invalid bucket counts and argument-count / type errors.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow::array::{
        ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array, Int64Array,
        IntervalYearMonthArray,
    };
    use arrow::datatypes::IntervalMonthDayNano;

    // --- Helpers -------------------------------------------------------------

    fn i64_array_all(len: usize, val: i64) -> Arc<Int64Array> {
        Arc::new(Int64Array::from(vec![val; len]))
    }

    fn f64_array(vals: &[f64]) -> Arc<Float64Array> {
        Arc::new(Float64Array::from(vals.to_vec()))
    }

    fn f64_array_opt(vals: &[Option<f64>]) -> Arc<Float64Array> {
        Arc::new(Float64Array::from(vals.to_vec()))
    }

    fn dur_us_array(vals: &[i64]) -> Arc<DurationMicrosecondArray> {
        Arc::new(DurationMicrosecondArray::from(vals.to_vec()))
    }

    fn ym_array(vals: &[i32]) -> Arc<IntervalYearMonthArray> {
        Arc::new(IntervalYearMonthArray::from(vals.to_vec()))
    }

    fn downcast_i32(arr: &ArrayRef) -> &Int32Array {
        arr.as_any().downcast_ref::<Int32Array>().unwrap()
    }

    // Builds a MonthDayNano interval array from (months, days, nanoseconds) triples.
    fn mdn_array(vals: &[(i32, i32, i64)]) -> Arc<IntervalMonthDayNanoArray> {
        let data: Vec<IntervalMonthDayNano> = vals
            .iter()
            .map(|(m, d, ns)| IntervalMonthDayNano::new(*m, *d, *ns))
            .collect();
        Arc::new(IntervalMonthDayNanoArray::from(data))
    }

    // --- Float64 -------------------------------------------------------------

    #[test]
    fn test_width_bucket_f64_basic() {
        let v = f64_array(&[0.5, 1.0, 9.9, -1.0, 10.0]);
        let lo = f64_array(&[0.0, 0.0, 0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0, 10.0, 10.0]);
        let n = i64_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 2, 10, 0, 11]);
    }

    #[test]
    fn test_width_bucket_f64_descending_range() {
        let v = f64_array(&[9.9, 10.0, 0.0, -0.1, 10.1]);
        let lo = f64_array(&[10.0; 5]);
        let hi = f64_array(&[0.0; 5]);
        let n = i64_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[1, 1, 11, 11, 0]);
    }
    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_asc() {
        let v = f64_array(&[0.0, 9.999999999, 10.0]);
        let lo = f64_array(&[0.0; 3]);
        let hi = f64_array(&[10.0; 3]);
        let n = i64_array_all(3, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 10, 11]);
    }

    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_desc() {
        let v = f64_array(&[10.0, 0.0, -0.000001]);
        let lo = f64_array(&[10.0; 3]);
        let hi = f64_array(&[0.0; 3]);
        let n = i64_array_all(3, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 11, 11]);
    }

    #[test]
    fn test_width_bucket_f64_edge_cases() {
        let v = f64_array(&[1.0, 5.0, 9.0]);
        let lo = f64_array(&[0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0]);
        let n = Arc::new(Int64Array::from(vec![0, -1, 10]));
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert!(out.is_null(1));
        assert_eq!(out.value(2), 10);

        let v = f64_array(&[1.0]);
        let lo = f64_array(&[5.0]);
        let hi = f64_array(&[5.0]);
        let n = i64_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));

        let v = f64_array_opt(&[Some(f64::NAN)]);
        let lo = f64_array(&[0.0]);
        let hi = f64_array(&[10.0]);
        let n = i64_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    #[test]
    fn test_width_bucket_f64_nulls_propagate() {
        let v = f64_array_opt(&[None, Some(1.0), Some(2.0), Some(3.0)]);
        let lo = f64_array(&[0.0; 4]);
        let hi = f64_array(&[10.0; 4]);
        let n = i64_array_all(4, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert_eq!(out.value(1), 2);
        assert_eq!(out.value(2), 3);
        assert_eq!(out.value(3), 4);

        let v = f64_array(&[1.0]);
        let lo = f64_array_opt(&[None]);
        let hi = f64_array(&[10.0]);
        let n = i64_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    // --- Duration(Microsecond) ----------------------------------------------

    #[test]
    fn test_width_bucket_duration_us() {
        let v = dur_us_array(&[1_000_000, 0, -1]);
        let lo = dur_us_array(&[0, 0, 0]);
        let hi = dur_us_array(&[2_000_000, 2_000_000, 2_000_000]);
        let n = i64_array_all(3, 2);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[2, 1, 0]);
    }

    #[test]
    fn test_width_bucket_duration_us_equal_bounds() {
        let v = dur_us_array(&[0]);
        let lo = dur_us_array(&[1]);
        let hi = dur_us_array(&[1]);
        let n = i64_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    // --- Interval(YearMonth) ------------------------------------------------

    #[test]
    fn test_width_bucket_interval_ym_basic() {
        let v = ym_array(&[0, 5, 11, 12, 13]);
        let lo = ym_array(&[0; 5]);
        let hi = ym_array(&[12; 5]);
        let n = i64_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 6, 12, 13, 13]);
    }

    #[test]
    fn test_width_bucket_interval_ym_desc() {
        let v = ym_array(&[11, 12, 0, -1, 13]);
        let lo = ym_array(&[12; 5]);
        let hi = ym_array(&[0; 5]);
        let n = i64_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[2, 1, 13, 13, 0]);
    }

    // --- Interval(MonthDayNano) --------------------------------------------

    #[test]
    fn test_width_bucket_interval_mdn_months_only_basic() {
        let v = mdn_array(&[(0, 0, 0), (5, 0, 0), (11, 0, 0), (12, 0, 0), (13, 0, 0)]);
        let lo = mdn_array(&[(0, 0, 0); 5]);
        let hi = mdn_array(&[(12, 0, 0); 5]);
        let n = i64_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 6, 12, 13, 13]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_months_only_desc() {
        let v = mdn_array(&[(11, 0, 0), (12, 0, 0), (0, 0, 0), (-1, 0, 0), (13, 0, 0)]);
        let lo = mdn_array(&[(12, 0, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i64_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        // Same expectations as `test_width_bucket_interval_ym_desc`.
        assert_eq!(out.values(), &[2, 1, 13, 13, 0]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_day_nano_basic() {
        let v = mdn_array(&[
            (0, 0, 0),
            (0, 5, 0),
            (0, 9, 0),
            (0, 10, 0),
            (0, -1, 0),
            (0, 11, 0),
        ]);
        let lo = mdn_array(&[(0, 0, 0); 6]);
        let hi = mdn_array(&[(0, 10, 0); 6]);
        let n = i64_array_all(6, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        // x==hi -> n+1, x<lo -> 0, x>hi -> n+1
        assert_eq!(out.values(), &[1, 6, 10, 11, 0, 11]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc() {
        let v = mdn_array(&[(0, 9, 0), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]);
        let lo = mdn_array(&[(0, 10, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i64_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[2, 1, 11, 11, 0]);
    }
    // One nanosecond past 9 days moves the first row from bucket 2 into bucket 1.
    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc_inside() {
        let v = mdn_array(&[(0, 9, 1), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]);
        let lo = mdn_array(&[(0, 10, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i64_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[1, 1, 11, 11, 0]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_mixed_months_and_days_is_null() {
        let v = mdn_array(&[(0, 1, 0)]);
        let lo = mdn_array(&[(0, 0, 0)]);
        let hi = mdn_array(&[(1, 1, 0)]);
        let n = i64_array_all(1, 4);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_equal_bounds_is_null() {
        let v = mdn_array(&[(0, 0, 0)]);
        let lo = mdn_array(&[(1, 2, 3)]);
        let hi = mdn_array(&[(1, 2, 3)]); // lo == hi
        let n = i64_array_all(1, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_invalid_n_is_null() {
        let v = mdn_array(&[(0, 0, 0)]);
        let lo = mdn_array(&[(0, 0, 0)]);
        let hi = mdn_array(&[(0, 10, 0)]);
        let n = Arc::new(Int64Array::from(vec![0])); // n <= 0

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_nulls_propagate() {
        let v = Arc::new(IntervalMonthDayNanoArray::from(vec![
            None,
            Some(IntervalMonthDayNano::new(0, 5, 0)),
        ]));
        let lo = mdn_array(&[(0, 0, 0), (0, 0, 0)]);
        let hi = mdn_array(&[(0, 10, 0), (0, 10, 0)]);
        let n = i64_array_all(2, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert_eq!(out.value(1), 6);
    }

    // --- Errors --------------------------------------------------------------

    #[test]
    fn test_width_bucket_wrong_arg_count() {
        let v = f64_array(&[1.0]);
        let lo = f64_array(&[0.0]);
        let hi = f64_array(&[10.0]);
        let err = width_bucket_kern(&[v, lo, hi]).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("expects exactly 4"), "unexpected error: {msg}");
    }

    #[test]
    fn test_width_bucket_unsupported_type() {
        let v: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let lo = f64_array(&[0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0]);
        let n = i64_array_all(3, 10);

        let err = width_bucket_kern(&[v, lo, hi, n]).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("width_bucket received unexpected data types"),
            "unexpected error: {msg}"
        );
    }
}