Skip to main content

datafusion_spark/function/math/
width_bucket.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{
21    Array, ArrayRef, DurationMicrosecondArray, Float64Array, IntervalMonthDayNanoArray,
22    IntervalYearMonthArray,
23};
24use arrow::datatypes::DataType;
25use arrow::datatypes::DataType::{Duration, Float64, Int32, Interval};
26use arrow::datatypes::IntervalUnit::{MonthDayNano, YearMonth};
27use datafusion_common::cast::{
28    as_duration_microsecond_array, as_float64_array, as_int64_array,
29    as_interval_mdn_array, as_interval_ym_array,
30};
31use datafusion_common::types::{
32    NativeType, logical_duration_microsecond, logical_float64, logical_int64,
33    logical_interval_mdn, logical_interval_year_month,
34};
35use datafusion_common::{Result, exec_err, internal_err};
36use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
37use datafusion_expr::{
38    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
39    TypeSignatureClass,
40};
41use datafusion_functions::utils::make_scalar_function;
42
43use arrow::array::{Int32Array, Int32Builder, Int64Array};
44use arrow::datatypes::TimeUnit::Microsecond;
45use datafusion_expr::Coercion;
46use datafusion_expr::Volatility::Immutable;
47
48#[derive(Debug, PartialEq, Eq, Hash)]
49pub struct SparkWidthBucket {
50    signature: Signature,
51}
52
53impl Default for SparkWidthBucket {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl SparkWidthBucket {
60    pub fn new() -> Self {
61        let numeric = Coercion::new_implicit(
62            TypeSignatureClass::Native(logical_float64()),
63            vec![TypeSignatureClass::Numeric],
64            NativeType::Float64,
65        );
66        let duration = Coercion::new_implicit(
67            TypeSignatureClass::Native(logical_duration_microsecond()),
68            vec![TypeSignatureClass::Duration],
69            NativeType::Duration(Microsecond),
70        );
71        let interval_ym = Coercion::new_exact(TypeSignatureClass::Native(
72            logical_interval_year_month(),
73        ));
74        let interval_mdn =
75            Coercion::new_exact(TypeSignatureClass::Native(logical_interval_mdn()));
76        let bucket = Coercion::new_implicit(
77            TypeSignatureClass::Native(logical_int64()),
78            vec![TypeSignatureClass::Integer],
79            NativeType::Int64,
80        );
81        let type_signature = Signature::one_of(
82            vec![
83                TypeSignature::Coercible(vec![
84                    numeric.clone(),
85                    numeric.clone(),
86                    numeric.clone(),
87                    bucket.clone(),
88                ]),
89                TypeSignature::Coercible(vec![
90                    duration.clone(),
91                    duration.clone(),
92                    duration.clone(),
93                    bucket.clone(),
94                ]),
95                TypeSignature::Coercible(vec![
96                    interval_ym.clone(),
97                    interval_ym.clone(),
98                    interval_ym.clone(),
99                    bucket.clone(),
100                ]),
101                TypeSignature::Coercible(vec![
102                    interval_mdn.clone(),
103                    interval_mdn.clone(),
104                    interval_mdn.clone(),
105                    bucket.clone(),
106                ]),
107            ],
108            Immutable,
109        )
110        .with_parameter_names(vec!["expr", "min", "max", "num_buckets"])
111        .expect("valid parameter names");
112        Self {
113            signature: type_signature,
114        }
115    }
116}
117
118impl ScalarUDFImpl for SparkWidthBucket {
119    fn name(&self) -> &str {
120        "width_bucket"
121    }
122
123    fn signature(&self) -> &Signature {
124        &self.signature
125    }
126
127    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
128        Ok(Int32)
129    }
130
131    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
132        make_scalar_function(width_bucket_kern, vec![])(&args.args)
133    }
134
135    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
136        if input.len() == 1 {
137            let value = &input[0];
138            Ok(value.sort_properties)
139        } else {
140            Ok(SortProperties::default())
141        }
142    }
143}
144
145fn width_bucket_kern(args: &[ArrayRef]) -> Result<ArrayRef> {
146    let [v, minv, maxv, nb] = args else {
147        return exec_err!(
148            "width_bucket expects exactly 4 argument, got {}",
149            args.len()
150        );
151    };
152
153    match v.data_type() {
154        Float64 => {
155            let v = as_float64_array(v)?;
156            let min = as_float64_array(minv)?;
157            let max = as_float64_array(maxv)?;
158            let n_bucket = as_int64_array(nb)?;
159            Ok(Arc::new(width_bucket_float64(v, min, max, n_bucket)))
160        }
161        Duration(Microsecond) => {
162            let v = as_duration_microsecond_array(v)?;
163            let min = as_duration_microsecond_array(minv)?;
164            let max = as_duration_microsecond_array(maxv)?;
165            let n_bucket = as_int64_array(nb)?;
166            Ok(Arc::new(width_bucket_i64_as_float(v, min, max, n_bucket)))
167        }
168        Interval(YearMonth) => {
169            let v = as_interval_ym_array(v)?;
170            let min = as_interval_ym_array(minv)?;
171            let max = as_interval_ym_array(maxv)?;
172            let n_bucket = as_int64_array(nb)?;
173            Ok(Arc::new(width_bucket_i32_as_float(v, min, max, n_bucket)))
174        }
175        Interval(MonthDayNano) => {
176            let v = as_interval_mdn_array(v)?;
177            let min = as_interval_mdn_array(minv)?;
178            let max = as_interval_mdn_array(maxv)?;
179            let n_bucket = as_int64_array(nb)?;
180            Ok(Arc::new(width_bucket_interval_mdn_exact(
181                v, min, max, n_bucket,
182            )))
183        }
184
185        other => internal_err!(
186            "width_bucket received unexpected data types: {:?}, {:?}, {:?}, {:?}",
187            other,
188            minv.data_type(),
189            maxv.data_type(),
190            nb.data_type()
191        ),
192    }
193}
194
/// Computes Spark's `width_bucket` for a single row of `f64` values.
///
/// Follows Spark's `WidthBucket` expression:
/// - Returns `None` (SQL NULL) when:
///   - `num_buckets <= 0`;
///   - `value` is NaN;
///   - `min` or `max` is NaN or infinite;
///   - `min == max`.
/// - Ascending range (`min < max`): returns `0` when `value < min`, and
///   `num_buckets + 1` when `value >= max`.
/// - Descending range (`min > max`): returns `0` when `value > min`, and
///   `num_buckets + 1` when `value <= max`.
/// - Otherwise returns `floor(num_buckets * (value - min) / (max - min)) + 1`.
///
/// An infinite `value` is *not* NULL; it simply falls outside the range.
///
/// The scaled position is computed by multiplying before dividing, as Spark
/// does. Dividing by a precomputed bucket width instead misplaces values that
/// sit exactly on a bucket boundary. For example, `0.3` in `[0, 1)` with 10
/// buckets gives `0.3 / 0.1 == 2.9999999999999996`, which is bucket 3 instead
/// of 4.
///
/// The output column is `Int32`, so `None` is also returned in two more cases
/// (rather than wrapping or panicking):
/// - the overflow bucket `num_buckets + 1` does not fit in an `i32`;
/// - `max - min` overflows to infinity.
fn width_bucket_value(value: f64, min: f64, max: f64, num_buckets: i64) -> Option<i32> {
    if num_buckets <= 0 {
        return None;
    }
    // The overflow bucket must be representable in the Int32 output.
    let overflow_bucket = i32::try_from(num_buckets).ok()?.checked_add(1)?;
    if value.is_nan() || !min.is_finite() || !max.is_finite() || min == max {
        return None;
    }

    let ascending = min < max;
    let (before_range, past_range) = if ascending {
        (value < min, value >= max)
    } else {
        (value > min, value <= max)
    };
    if before_range {
        return Some(0);
    }
    if past_range {
        return Some(overflow_bucket);
    }

    let span = max - min;
    if !span.is_finite() {
        // `max - min` overflowed: there is no meaningful bucket width.
        return None;
    }
    // Spark's order of operations. `value - min` and `span` share a sign, so
    // this also covers descending ranges.
    let scaled = num_buckets as f64 * (value - min) / span;
    // `scaled` lies in [0, num_buckets] up to rounding. Saturate and clamp so
    // rounding at the extremes can never escape [1, num_buckets + 1].
    Some((scaled.floor() as i32).saturating_add(1).clamp(1, overflow_bucket))
}

/// Generates a typed `width_bucket` array kernel that applies
/// `width_bucket_value` row by row.
///
/// - `$name`: name of the generated `pub(crate)` function.
/// - `$arr_ty`: Arrow array type shared by `expr`, `min` and `max`.
/// - `$to_f64`: closure `(array, row) -> f64` that reads one value.
/// - `$check_nan`: accepted so existing invocations stay valid. The NaN and
///   infinity handling lives in `width_bucket_value` and applies to every
///   type. For integer-backed arrays it never fires, because their values
///   always convert to finite `f64`.
///
/// An output row is NULL when any input is NULL or when `width_bucket_value`
/// returns `None`.
macro_rules! width_bucket_kernel_impl {
    ($name:ident, $arr_ty:ty, $to_f64:expr, $check_nan:expr) => {
        pub(crate) fn $name(
            v: &$arr_ty,
            min: &$arr_ty,
            max: &$arr_ty,
            n_bucket: &Int64Array,
        ) -> Int32Array {
            let mut b = Int32Builder::with_capacity(v.len());

            for i in 0..v.len() {
                let any_null =
                    v.is_null(i) || min.is_null(i) || max.is_null(i) || n_bucket.is_null(i);
                let bucket = if any_null {
                    None
                } else {
                    width_bucket_value(
                        ($to_f64)(v, i),
                        ($to_f64)(min, i),
                        ($to_f64)(max, i),
                        n_bucket.value(i),
                    )
                };
                b.append_option(bucket);
            }

            b.finish()
        }
    };
}
282
283width_bucket_kernel_impl!(
284    width_bucket_float64,
285    Float64Array,
286    |arr: &Float64Array, i: usize| arr.value(i),
287    true
288);
289
290width_bucket_kernel_impl!(
291    width_bucket_i64_as_float,
292    DurationMicrosecondArray,
293    |arr: &DurationMicrosecondArray, i: usize| arr.value(i) as f64,
294    false
295);
296
297width_bucket_kernel_impl!(
298    width_bucket_i32_as_float,
299    IntervalYearMonthArray,
300    |arr: &IntervalYearMonthArray, i: usize| arr.value(i) as f64,
301    false
302);
303const NS_PER_DAY_I128: i128 = 86_400_000_000_000;
304pub(crate) fn width_bucket_interval_mdn_exact(
305    v: &IntervalMonthDayNanoArray,
306    lo: &IntervalMonthDayNanoArray,
307    hi: &IntervalMonthDayNanoArray,
308    n: &Int64Array,
309) -> Int32Array {
310    let len = v.len();
311    let mut b = Int32Builder::with_capacity(len);
312
313    for i in 0..len {
314        if v.is_null(i) || lo.is_null(i) || hi.is_null(i) || n.is_null(i) {
315            b.append_null();
316            continue;
317        }
318        let buckets = n.value(i);
319        if buckets <= 0 {
320            b.append_null();
321            continue;
322        }
323        let next_bucket = (buckets + 1) as i32;
324
325        let x = v.value(i);
326        let l = lo.value(i);
327        let h = hi.value(i);
328
329        // asc/desc
330        // Values of IntervalMonthDayNano are compared using their binary representation, which can lead to surprising results.
331        let asc = (l.months, l.days, l.nanoseconds) < (h.months, h.days, h.nanoseconds);
332        if (l.months, l.days, l.nanoseconds) == (h.months, h.days, h.nanoseconds) {
333            b.append_null();
334            continue;
335        }
336
337        // ------------------- only month -------------------
338        if l.days == h.days && l.nanoseconds == h.nanoseconds && l.months != h.months {
339            let x_m = x.months as f64;
340            let l_m = l.months as f64;
341            let h_m = h.months as f64;
342
343            if asc {
344                if x_m < l_m {
345                    b.append_value(0);
346                    continue;
347                }
348                if x_m >= h_m {
349                    b.append_value(next_bucket);
350                    continue;
351                }
352            } else {
353                if x_m > l_m {
354                    b.append_value(0);
355                    continue;
356                }
357                if x_m <= h_m {
358                    b.append_value(next_bucket);
359                    continue;
360                }
361            }
362
363            let width = (h_m - l_m) / (buckets as f64);
364            if width == 0.0 || !width.is_finite() {
365                b.append_null();
366                continue;
367            }
368
369            let mut bucket = ((x_m - l_m) / width).floor() as i32 + 1;
370            if bucket < 1 {
371                bucket = 1;
372            }
373            if bucket > next_bucket {
374                bucket = next_bucket;
375            }
376            b.append_value(bucket);
377            continue;
378        }
379
380        // ---------------  months equals -------------------
381        if l.months == h.months {
382            let base_days = l.days as i128;
383            let base_ns = l.nanoseconds as i128;
384
385            let xf = (x.days as i128 - base_days) * NS_PER_DAY_I128
386                + (x.nanoseconds as i128 - base_ns);
387            let hf = (h.days as i128 - base_days) * NS_PER_DAY_I128
388                + (h.nanoseconds as i128 - base_ns);
389
390            let x_f = xf as f64;
391            let l_f = 0.0;
392            let h_f = hf as f64;
393
394            if asc {
395                if x_f < l_f {
396                    b.append_value(0);
397                    continue;
398                }
399                if x_f >= h_f {
400                    b.append_value(next_bucket);
401                    continue;
402                }
403            } else {
404                if x_f > l_f {
405                    b.append_value(0);
406                    continue;
407                }
408                if x_f <= h_f {
409                    b.append_value(next_bucket);
410                    continue;
411                }
412            }
413
414            let width = (h_f - l_f) / (buckets as f64);
415            if width == 0.0 || !width.is_finite() {
416                b.append_null();
417                continue;
418            }
419
420            let mut bucket = ((x_f - l_f) / width).floor() as i32 + 1;
421            if bucket < 1 {
422                bucket = 1;
423            }
424            if bucket > next_bucket {
425                bucket = next_bucket;
426            }
427            b.append_value(bucket);
428            continue;
429        }
430
431        b.append_null();
432    }
433
434    b.finish()
435}
436
// Unit tests that drive `width_bucket_kern` directly with Arrow arrays, one
// section per supported input family, plus the error paths.
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::array::{
        ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array, Int64Array,
        IntervalYearMonthArray,
    };
    use arrow::datatypes::IntervalMonthDayNano;

    // --- Helpers -------------------------------------------------------------

    // Int64 array holding `len` copies of `val` (used for `num_buckets`).
    fn i64_array_all(len: usize, val: i64) -> Arc<Int64Array> {
        Arc::new(Int64Array::from(vec![val; len]))
    }

    // Non-null Float64 array.
    fn f64_array(vals: &[f64]) -> Arc<Float64Array> {
        Arc::new(Float64Array::from(vals.to_vec()))
    }

    // Float64 array where `None` becomes a NULL slot.
    fn f64_array_opt(vals: &[Option<f64>]) -> Arc<Float64Array> {
        Arc::new(Float64Array::from(vals.to_vec()))
    }

    // Duration(Microsecond) array from raw microsecond counts.
    fn dur_us_array(vals: &[i64]) -> Arc<DurationMicrosecondArray> {
        Arc::new(DurationMicrosecondArray::from(vals.to_vec()))
    }

    // Interval(YearMonth) array from raw month counts.
    fn ym_array(vals: &[i32]) -> Arc<IntervalYearMonthArray> {
        Arc::new(IntervalYearMonthArray::from(vals.to_vec()))
    }

    // Views the kernel output as the Int32Array it is expected to be.
    fn downcast_i32(arr: &ArrayRef) -> &Int32Array {
        arr.as_any().downcast_ref::<Int32Array>().unwrap()
    }

    // Interval(MonthDayNano) array from `(months, days, nanoseconds)` tuples.
    fn mdn_array(vals: &[(i32, i32, i64)]) -> Arc<IntervalMonthDayNanoArray> {
        let data: Vec<IntervalMonthDayNano> = vals
            .iter()
            .map(|(m, d, ns)| IntervalMonthDayNano::new(*m, *d, *ns))
            .collect();
        Arc::new(IntervalMonthDayNanoArray::from(data))
    }

    // --- Float64 -------------------------------------------------------------

    #[test]
    fn test_width_bucket_f64_basic() {
        let v = f64_array(&[0.5, 1.0, 9.9, -1.0, 10.0]);
        let lo = f64_array(&[0.0, 0.0, 0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0, 10.0, 10.0]);
        let n = i64_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 2, 10, 0, 11]);
    }

    // `min > max`: larger values fall into lower buckets.
    #[test]
    fn test_width_bucket_f64_descending_range() {
        let v = f64_array(&[9.9, 10.0, 0.0, -0.1, 10.1]);
        let lo = f64_array(&[10.0; 5]);
        let hi = f64_array(&[0.0; 5]);
        let n = i64_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[1, 1, 11, 11, 0]);
    }
    // Ascending range: `min` is inclusive (bucket 1), `max` is exclusive (n + 1).
    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_asc() {
        let v = f64_array(&[0.0, 9.999999999, 10.0]);
        let lo = f64_array(&[0.0; 3]);
        let hi = f64_array(&[10.0; 3]);
        let n = i64_array_all(3, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 10, 11]);
    }

    // Descending range: `min` is inclusive (bucket 1), `max` and beyond map to n + 1.
    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_desc() {
        let v = f64_array(&[10.0, 0.0, -0.000001]);
        let lo = f64_array(&[10.0; 3]);
        let hi = f64_array(&[0.0; 3]);
        let n = i64_array_all(3, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 11, 11]);
    }

    // NULL results for non-positive bucket counts, min == max, and a NaN value.
    #[test]
    fn test_width_bucket_f64_edge_cases() {
        let v = f64_array(&[1.0, 5.0, 9.0]);
        let lo = f64_array(&[0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0]);
        let n = Arc::new(Int64Array::from(vec![0, -1, 10]));
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert!(out.is_null(1));
        assert_eq!(out.value(2), 10);

        let v = f64_array(&[1.0]);
        let lo = f64_array(&[5.0]);
        let hi = f64_array(&[5.0]);
        let n = i64_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));

        let v = f64_array_opt(&[Some(f64::NAN)]);
        let lo = f64_array(&[0.0]);
        let hi = f64_array(&[10.0]);
        let n = i64_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    // A NULL in any argument yields NULL for that row only.
    #[test]
    fn test_width_bucket_f64_nulls_propagate() {
        let v = f64_array_opt(&[None, Some(1.0), Some(2.0), Some(3.0)]);
        let lo = f64_array(&[0.0; 4]);
        let hi = f64_array(&[10.0; 4]);
        let n = i64_array_all(4, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert_eq!(out.value(1), 2);
        assert_eq!(out.value(2), 3);
        assert_eq!(out.value(3), 4);

        let v = f64_array(&[1.0]);
        let lo = f64_array_opt(&[None]);
        let hi = f64_array(&[10.0]);
        let n = i64_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    // --- Duration(Microsecond) ----------------------------------------------

    #[test]
    fn test_width_bucket_duration_us() {
        let v = dur_us_array(&[1_000_000, 0, -1]);
        let lo = dur_us_array(&[0, 0, 0]);
        let hi = dur_us_array(&[2_000_000, 2_000_000, 2_000_000]);
        let n = i64_array_all(3, 2);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[2, 1, 0]);
    }

    #[test]
    fn test_width_bucket_duration_us_equal_bounds() {
        let v = dur_us_array(&[0]);
        let lo = dur_us_array(&[1]);
        let hi = dur_us_array(&[1]);
        let n = i64_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    // --- Interval(YearMonth) ------------------------------------------------

    #[test]
    fn test_width_bucket_interval_ym_basic() {
        let v = ym_array(&[0, 5, 11, 12, 13]);
        let lo = ym_array(&[0; 5]);
        let hi = ym_array(&[12; 5]);
        let n = i64_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 6, 12, 13, 13]);
    }

    #[test]
    fn test_width_bucket_interval_ym_desc() {
        let v = ym_array(&[11, 12, 0, -1, 13]);
        let lo = ym_array(&[12; 5]);
        let hi = ym_array(&[0; 5]);
        let n = i64_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[2, 1, 13, 13, 0]);
    }

    // --- Interval(MonthDayNano) --------------------------------------------

    // Bounds differ only in months: bucketed like Interval(YearMonth).
    #[test]
    fn test_width_bucket_interval_mdn_months_only_basic() {
        let v = mdn_array(&[(0, 0, 0), (5, 0, 0), (11, 0, 0), (12, 0, 0), (13, 0, 0)]);
        let lo = mdn_array(&[(0, 0, 0); 5]);
        let hi = mdn_array(&[(12, 0, 0); 5]);
        let n = i64_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 6, 12, 13, 13]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_months_only_desc() {
        let v = mdn_array(&[(11, 0, 0), (12, 0, 0), (0, 0, 0), (-1, 0, 0), (13, 0, 0)]);
        let lo = mdn_array(&[(12, 0, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i64_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        // Same pattern as the descending Interval(YearMonth) case.
        assert_eq!(out.values(), &[2, 1, 13, 13, 0]);
    }

    // Bounds share months and differ in days: bucketed on the day/nanosecond axis.
    #[test]
    fn test_width_bucket_interval_mdn_day_nano_basic() {
        let v = mdn_array(&[
            (0, 0, 0),
            (0, 5, 0),
            (0, 9, 0),
            (0, 10, 0),
            (0, -1, 0),
            (0, 11, 0),
        ]);
        let lo = mdn_array(&[(0, 0, 0); 6]);
        let hi = mdn_array(&[(0, 10, 0); 6]);
        let n = i64_array_all(6, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        // x==hi -> n+1, x<lo -> 0, x>hi -> n+1
        assert_eq!(out.values(), &[1, 6, 10, 11, 0, 11]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc() {
        let v = mdn_array(&[(0, 9, 0), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]);
        let lo = mdn_array(&[(0, 10, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i64_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[2, 1, 11, 11, 0]);
    }
    // One nanosecond past 9 days stays in bucket 1 of a descending 10..0 day range.
    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc_inside() {
        let v = mdn_array(&[(0, 9, 1), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]);
        let lo = mdn_array(&[(0, 10, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i64_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[1, 1, 11, 11, 0]);
    }

    // Bounds that differ in both months and days cannot be bucketed.
    #[test]
    fn test_width_bucket_interval_mdn_mixed_months_and_days_is_null() {
        let v = mdn_array(&[(0, 1, 0)]);
        let lo = mdn_array(&[(0, 0, 0)]);
        let hi = mdn_array(&[(1, 1, 0)]);
        let n = i64_array_all(1, 4);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_equal_bounds_is_null() {
        let v = mdn_array(&[(0, 0, 0)]);
        let lo = mdn_array(&[(1, 2, 3)]);
        let hi = mdn_array(&[(1, 2, 3)]); // lo == hi
        let n = i64_array_all(1, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_invalid_n_is_null() {
        let v = mdn_array(&[(0, 0, 0)]);
        let lo = mdn_array(&[(0, 0, 0)]);
        let hi = mdn_array(&[(0, 10, 0)]);
        let n = Arc::new(Int64Array::from(vec![0])); // n <= 0

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_nulls_propagate() {
        let v = Arc::new(IntervalMonthDayNanoArray::from(vec![
            None,
            Some(IntervalMonthDayNano::new(0, 5, 0)),
        ]));
        let lo = mdn_array(&[(0, 0, 0), (0, 0, 0)]);
        let hi = mdn_array(&[(0, 10, 0), (0, 10, 0)]);
        let n = i64_array_all(2, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert_eq!(out.value(1), 6);
    }

    // --- Errors --------------------------------------------------------------

    #[test]
    fn test_width_bucket_wrong_arg_count() {
        let v = f64_array(&[1.0]);
        let lo = f64_array(&[0.0]);
        let hi = f64_array(&[10.0]);
        let err = width_bucket_kern(&[v, lo, hi]).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("expects exactly 4"), "unexpected error: {msg}");
    }

    // An `expr` type outside the signature reaches the internal-error fallback.
    #[test]
    fn test_width_bucket_unsupported_type() {
        let v: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let lo = f64_array(&[0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0]);
        let n = i64_array_all(3, 10);

        let err = width_bucket_kern(&[v, lo, hi, n]).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("width_bucket received unexpected data types"),
            "unexpected error: {msg}"
        );
    }
}