1use std::any::Any;
19use std::sync::Arc;
20
21use crate::function::error_utils::unsupported_data_types_exec_err;
22use arrow::array::{
23 Array, ArrayRef, DurationMicrosecondArray, Float64Array, IntervalMonthDayNanoArray,
24 IntervalYearMonthArray,
25};
26use arrow::datatypes::DataType;
27use arrow::datatypes::DataType::{Duration, Float64, Int32, Interval};
28use arrow::datatypes::IntervalUnit::{MonthDayNano, YearMonth};
29use datafusion_common::cast::{
30 as_duration_microsecond_array, as_float64_array, as_int32_array,
31 as_interval_mdn_array, as_interval_ym_array,
32};
33use datafusion_common::{exec_err, Result};
34use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
35use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature};
36use datafusion_functions::utils::make_scalar_function;
37
38use arrow::array::{Int32Array, Int32Builder};
39use arrow::datatypes::TimeUnit::Microsecond;
40use datafusion_expr::Volatility::Immutable;
41
42#[derive(Debug, PartialEq, Eq, Hash)]
43pub struct SparkWidthBucket {
44 signature: Signature,
45}
46
47impl Default for SparkWidthBucket {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl SparkWidthBucket {
54 pub fn new() -> Self {
55 Self {
56 signature: Signature::user_defined(Immutable),
57 }
58 }
59}
60
61impl ScalarUDFImpl for SparkWidthBucket {
62 fn as_any(&self) -> &dyn Any {
63 self
64 }
65
66 fn name(&self) -> &str {
67 "width_bucket"
68 }
69
70 fn signature(&self) -> &Signature {
71 &self.signature
72 }
73
74 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
75 Ok(Int32)
76 }
77
78 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
79 make_scalar_function(width_bucket_kern, vec![])(&args.args)
80 }
81
82 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
83 if input.len() == 1 {
84 let value = &input[0];
85 Ok(value.sort_properties)
86 } else {
87 Ok(SortProperties::default())
88 }
89 }
90
91 fn coerce_types(&self, types: &[DataType]) -> Result<Vec<DataType>> {
92 use DataType::*;
93
94 let (v, lo, hi, n) = (&types[0], &types[1], &types[2], &types[3]);
95
96 let is_num = |t: &DataType| {
97 matches!(
98 t,
99 Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
100 )
101 };
102
103 match (v, lo, hi, n) {
104 (a, b, c, &(Int8 | Int16 | Int32 | Int64))
105 if is_num(a) && is_num(b) && is_num(c) =>
106 {
107 Ok(vec![Float64, Float64, Float64, Int32])
108 }
109 (
110 &Duration(_),
111 &Duration(_),
112 &Duration(_),
113 &(Int8 | Int16 | Int32 | Int64),
114 ) => Ok(vec![
115 Duration(Microsecond),
116 Duration(Microsecond),
117 Duration(Microsecond),
118 Int32,
119 ]),
120 (
121 &Interval(MonthDayNano),
122 &Interval(MonthDayNano),
123 &Interval(MonthDayNano),
124 &(Int8 | Int16 | Int32 | Int64),
125 ) => Ok(vec![
126 Interval(MonthDayNano),
127 Interval(MonthDayNano),
128 Interval(MonthDayNano),
129 Int32,
130 ]),
131 (
132 &Interval(YearMonth),
133 &Interval(YearMonth),
134 &Interval(YearMonth),
135 &(Int8 | Int16 | Int32 | Int64),
136 ) => Ok(vec![
137 Interval(YearMonth),
138 Interval(YearMonth),
139 Interval(YearMonth),
140 Int32,
141 ]),
142
143 _ => exec_err!(
144 "width_bucket expects a numeric argument, got {} {} {} {}",
145 types[0],
146 types[1],
147 types[2],
148 types[3]
149 ),
150 }
151 }
152}
153
154fn width_bucket_kern(args: &[ArrayRef]) -> Result<ArrayRef> {
155 let [v, minv, maxv, nb] = args else {
156 return exec_err!(
157 "width_bucket expects exactly 4 argument, got {}",
158 args.len()
159 );
160 };
161
162 match v.data_type() {
163 Float64 => {
164 let v = as_float64_array(v)?;
165 let min = as_float64_array(minv)?;
166 let max = as_float64_array(maxv)?;
167 let n_bucket = as_int32_array(nb)?;
168 Ok(Arc::new(width_bucket_float64(v, min, max, n_bucket)))
169 }
170 Duration(Microsecond) => {
171 let v = as_duration_microsecond_array(v)?;
172 let min = as_duration_microsecond_array(minv)?;
173 let max = as_duration_microsecond_array(maxv)?;
174 let n_bucket = as_int32_array(nb)?;
175 Ok(Arc::new(width_bucket_i64_as_float(v, min, max, n_bucket)))
176 }
177 Interval(YearMonth) => {
178 let v = as_interval_ym_array(v)?;
179 let min = as_interval_ym_array(minv)?;
180 let max = as_interval_ym_array(maxv)?;
181 let n_bucket = as_int32_array(nb)?;
182 Ok(Arc::new(width_bucket_i32_as_float(v, min, max, n_bucket)))
183 }
184 Interval(MonthDayNano) => {
185 let v = as_interval_mdn_array(v)?;
186 let min = as_interval_mdn_array(minv)?;
187 let max = as_interval_mdn_array(maxv)?;
188 let n_bucket = as_int32_array(nb)?;
189 Ok(Arc::new(width_bucket_interval_mdn_exact(v, min, max, n_bucket)))
190 }
191
192
193 other => Err(unsupported_data_types_exec_err(
194 "width_bucket",
195 "Float/Decimal OR Duration OR Interval(YearMonth) for first 3 args; Int for 4th",
196 &[
197 other.clone(),
198 minv.data_type().clone(),
199 maxv.data_type().clone(),
200 nb.data_type().clone(),
201 ],
202 )),
203 }
204}
205
/// Generates a row-wise `width_bucket` kernel for one Arrow array type.
///
/// * `$name`      - name of the generated `pub(crate)` function.
/// * `$arr_ty`    - array type shared by the value and both bounds.
/// * `$to_f64`    - closure `(&$arr_ty, usize) -> f64` reading one row as `f64`.
/// * `$check_nan` - kept so existing invocations still compile. The NaN /
///   infinity rules are applied by `compute_bucket_number` for every input
///   type. They can never trigger for integer-backed inputs, so the flag has
///   no effect.
///
/// A row is null when any of its four inputs is null. Otherwise it is
/// `compute_bucket_number(value, min, max, n_bucket)`.
macro_rules! width_bucket_kernel_impl {
    ($name:ident, $arr_ty:ty, $to_f64:expr, $check_nan:expr) => {
        pub(crate) fn $name(
            v: &$arr_ty,
            min: &$arr_ty,
            max: &$arr_ty,
            n_bucket: &Int32Array,
        ) -> Int32Array {
            let len = v.len();
            let mut b = Int32Builder::with_capacity(len);

            for i in 0..len {
                if v.is_null(i) || min.is_null(i) || max.is_null(i) || n_bucket.is_null(i)
                {
                    b.append_null();
                    continue;
                }
                b.append_option(compute_bucket_number(
                    ($to_f64)(v, i),
                    ($to_f64)(min, i),
                    ($to_f64)(max, i),
                    n_bucket.value(i),
                ));
            }

            b.finish()
        }
    };
}

/// Spark-compatible bucket number for `value` within the range `min..max`,
/// split into `num_bucket` equal-width buckets. This mirrors Spark's
/// `WidthBucket.computeBucketNumber`.
///
/// Returns `None` (SQL NULL) when:
/// - `num_bucket <= 0`;
/// - `num_bucket == i32::MAX`, because the overflow bucket `num_bucket + 1`
///   would not fit in the `Int32` result (Spark likewise nulls `Long.MaxValue`);
/// - `value` is NaN;
/// - `min == max`;
/// - either bound is NaN or infinite.
///
/// An infinite `value` is *not* null: it lands in bucket `0` or `num_bucket + 1`.
///
/// * Ascending (`min < max`):
///   - `0` below `min`;
///   - `num_bucket + 1` at or above `max`;
///   - otherwise `floor(num_bucket * (value - min) / (max - min)) + 1`.
/// * Descending (`min > max`):
///   - `0` above `min`;
///   - `num_bucket + 1` at or below `max`;
///   - otherwise `floor(num_bucket * (min - value) / (min - max)) + 1`.
///
/// The multiplication is done before the division, as in Spark. Dividing by a
/// precomputed bucket width instead misplaces values that sit exactly on a
/// boundary (e.g. `0.3 / 0.1 == 2.9999999999999996`).
fn compute_bucket_number(value: f64, min: f64, max: f64, num_bucket: i32) -> Option<i32> {
    if num_bucket <= 0
        || num_bucket == i32::MAX
        || value.is_nan()
        || min == max
        || !min.is_finite()
        || !max.is_finite()
    {
        return None;
    }

    // Distance of `value` from the start of the range, and the range's length,
    // both measured in the direction the range runs.
    let (offset, span) = if min < max {
        if value < min {
            return Some(0);
        }
        if value >= max {
            return Some(num_bucket + 1);
        }
        (value - min, max - min)
    } else {
        if value > min {
            return Some(0);
        }
        if value <= max {
            return Some(num_bucket + 1);
        }
        (min - value, min - max)
    };

    let n = f64::from(num_bucket);
    let mut scaled = n * offset / span;
    if !scaled.is_finite() {
        // `n * offset` overflowed f64. `offset / span` is at most 1 (or NaN if
        // the span itself overflowed), so dividing first stays in range.
        scaled = n * (offset / span);
    }
    // `as` truncates toward zero (a floor, since `scaled >= 0`) and maps NaN to 0.
    // Rounding can still yield exactly `n`, i.e. the overflow bucket, as in Spark.
    // The clamp only keeps the `+ 1` from exceeding `num_bucket + 1`.
    Some((scaled as i32).clamp(0, num_bucket) + 1)
}
292
293width_bucket_kernel_impl!(
294 width_bucket_float64,
295 Float64Array,
296 |arr: &Float64Array, i: usize| arr.value(i),
297 true
298);
299
300width_bucket_kernel_impl!(
301 width_bucket_i64_as_float,
302 DurationMicrosecondArray,
303 |arr: &DurationMicrosecondArray, i: usize| arr.value(i) as f64,
304 false
305);
306
307width_bucket_kernel_impl!(
308 width_bucket_i32_as_float,
309 IntervalYearMonthArray,
310 |arr: &IntervalYearMonthArray, i: usize| arr.value(i) as f64,
311 false
312);
313const NS_PER_DAY_I128: i128 = 86_400_000_000_000;
314pub(crate) fn width_bucket_interval_mdn_exact(
315 v: &IntervalMonthDayNanoArray,
316 lo: &IntervalMonthDayNanoArray,
317 hi: &IntervalMonthDayNanoArray,
318 n: &Int32Array,
319) -> Int32Array {
320 let len = v.len();
321 let mut b = Int32Builder::with_capacity(len);
322
323 for i in 0..len {
324 if v.is_null(i) || lo.is_null(i) || hi.is_null(i) || n.is_null(i) {
325 b.append_null();
326 continue;
327 }
328 let buckets = n.value(i);
329 if buckets <= 0 {
330 b.append_null();
331 continue;
332 }
333
334 let x = v.value(i);
335 let l = lo.value(i);
336 let h = hi.value(i);
337
338 let asc = (l.months, l.days, l.nanoseconds) < (h.months, h.days, h.nanoseconds);
341 if (l.months, l.days, l.nanoseconds) == (h.months, h.days, h.nanoseconds) {
342 b.append_null();
343 continue;
344 }
345
346 if l.days == h.days && l.nanoseconds == h.nanoseconds && l.months != h.months {
348 let x_m = x.months as f64;
349 let l_m = l.months as f64;
350 let h_m = h.months as f64;
351
352 if asc {
353 if x_m < l_m {
354 b.append_value(0);
355 continue;
356 }
357 if x_m >= h_m {
358 b.append_value(buckets + 1);
359 continue;
360 }
361 } else {
362 if x_m > l_m {
363 b.append_value(0);
364 continue;
365 }
366 if x_m <= h_m {
367 b.append_value(buckets + 1);
368 continue;
369 }
370 }
371
372 let width = (h_m - l_m) / (buckets as f64);
373 if width == 0.0 || !width.is_finite() {
374 b.append_null();
375 continue;
376 }
377
378 let mut bucket = ((x_m - l_m) / width).floor() as i32 + 1;
379 if bucket < 1 {
380 bucket = 1;
381 }
382 if bucket > buckets + 1 {
383 bucket = buckets + 1;
384 }
385 b.append_value(bucket);
386 continue;
387 }
388
389 if l.months == h.months {
391 let base_days = l.days as i128;
392 let base_ns = l.nanoseconds as i128;
393
394 let xf = (x.days as i128 - base_days) * NS_PER_DAY_I128
395 + (x.nanoseconds as i128 - base_ns);
396 let hf = (h.days as i128 - base_days) * NS_PER_DAY_I128
397 + (h.nanoseconds as i128 - base_ns);
398
399 let x_f = xf as f64;
400 let l_f = 0.0;
401 let h_f = hf as f64;
402
403 if asc {
404 if x_f < l_f {
405 b.append_value(0);
406 continue;
407 }
408 if x_f >= h_f {
409 b.append_value(buckets + 1);
410 continue;
411 }
412 } else {
413 if x_f > l_f {
414 b.append_value(0);
415 continue;
416 }
417 if x_f <= h_f {
418 b.append_value(buckets + 1);
419 continue;
420 }
421 }
422
423 let width = (h_f - l_f) / (buckets as f64);
424 if width == 0.0 || !width.is_finite() {
425 b.append_null();
426 continue;
427 }
428
429 let mut bucket = ((x_f - l_f) / width).floor() as i32 + 1;
430 if bucket < 1 {
431 bucket = 1;
432 }
433 if bucket > buckets + 1 {
434 bucket = buckets + 1;
435 }
436 b.append_value(bucket);
437 continue;
438 }
439
440 b.append_null();
441 }
442
443 b.finish()
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow::array::{
        ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array,
        IntervalYearMonthArray,
    };
    use arrow::datatypes::IntervalMonthDayNano;

    // ---- Array construction helpers ----------------------------------------

    /// Int32 array holding `len` copies of `val`; used for the bucket-count argument.
    fn i32_array_all(len: usize, val: i32) -> Arc<Int32Array> {
        Arc::new(Int32Array::from(vec![val; len]))
    }

    /// Non-null Float64 array built from `vals`.
    fn f64_array(vals: &[f64]) -> Arc<Float64Array> {
        Arc::new(Float64Array::from(vals.to_vec()))
    }

    /// Float64 array in which `None` entries become nulls.
    fn f64_array_opt(vals: &[Option<f64>]) -> Arc<Float64Array> {
        Arc::new(Float64Array::from(vals.to_vec()))
    }

    /// Duration(Microsecond) array built from raw microsecond counts.
    fn dur_us_array(vals: &[i64]) -> Arc<DurationMicrosecondArray> {
        Arc::new(DurationMicrosecondArray::from(vals.to_vec()))
    }

    /// Interval(YearMonth) array built from raw month counts.
    fn ym_array(vals: &[i32]) -> Arc<IntervalYearMonthArray> {
        Arc::new(IntervalYearMonthArray::from(vals.to_vec()))
    }

    /// Downcasts kernel output to `Int32Array`; panics if the output type differs.
    fn downcast_i32(arr: &ArrayRef) -> &Int32Array {
        arr.as_any().downcast_ref::<Int32Array>().unwrap()
    }

    /// Interval(MonthDayNano) array built from `(months, days, nanoseconds)` triples.
    fn mdn_array(vals: &[(i32, i32, i64)]) -> Arc<IntervalMonthDayNanoArray> {
        let data: Vec<IntervalMonthDayNano> = vals
            .iter()
            .map(|(m, d, ns)| IntervalMonthDayNano::new(*m, *d, *ns))
            .collect();
        Arc::new(IntervalMonthDayNanoArray::from(data))
    }

    // ---- Float64 -----------------------------------------------------------

    /// Ascending [0, 10) with 10 buckets: below the range -> 0, at `hi` -> 11.
    #[test]
    fn test_width_bucket_f64_basic() {
        let v = f64_array(&[0.5, 1.0, 9.9, -1.0, 10.0]);
        let lo = f64_array(&[0.0, 0.0, 0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0, 10.0, 10.0]);
        let n = i32_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 2, 10, 0, 11]);
    }

    /// Descending range (lo = 10, hi = 0): `lo` itself is bucket 1, values above
    /// `lo` map to 0, and values at or below `hi` map to n + 1.
    #[test]
    fn test_width_bucket_f64_descending_range() {
        let v = f64_array(&[9.9, 10.0, 0.0, -0.1, 10.1]);
        let lo = f64_array(&[10.0; 5]);
        let hi = f64_array(&[0.0; 5]);
        let n = i32_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[1, 1, 11, 11, 0]);
    }

    /// Ascending range is inclusive at `lo` and exclusive at `hi`.
    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_asc() {
        let v = f64_array(&[0.0, 9.999999999, 10.0]);
        let lo = f64_array(&[0.0; 3]);
        let hi = f64_array(&[10.0; 3]);
        let n = i32_array_all(3, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 10, 11]);
    }

    /// Descending range is inclusive at `lo` and exclusive at `hi`.
    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_desc() {
        let v = f64_array(&[10.0, 0.0, -0.000001]);
        let lo = f64_array(&[10.0; 3]);
        let hi = f64_array(&[0.0; 3]);
        let n = i32_array_all(3, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 11, 11]);
    }

    /// Null results for a non-positive bucket count, an empty range (lo == hi),
    /// and a NaN value.
    #[test]
    fn test_width_bucket_f64_edge_cases() {
        let v = f64_array(&[1.0, 5.0, 9.0]);
        let lo = f64_array(&[0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0]);
        let n = Arc::new(Int32Array::from(vec![0, -1, 10]));
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert!(out.is_null(1));
        assert_eq!(out.value(2), 10);

        let v = f64_array(&[1.0]);
        let lo = f64_array(&[5.0]);
        let hi = f64_array(&[5.0]);
        let n = i32_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));

        let v = f64_array_opt(&[Some(f64::NAN)]);
        let lo = f64_array(&[0.0]);
        let hi = f64_array(&[10.0]);
        let n = i32_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    /// A null in the value or in a bound yields a null row; other rows are unaffected.
    #[test]
    fn test_width_bucket_f64_nulls_propagate() {
        let v = f64_array_opt(&[None, Some(1.0), Some(2.0), Some(3.0)]);
        let lo = f64_array(&[0.0; 4]);
        let hi = f64_array(&[10.0; 4]);
        let n = i32_array_all(4, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert_eq!(out.value(1), 2);
        assert_eq!(out.value(2), 3);
        assert_eq!(out.value(3), 4);

        let v = f64_array(&[1.0]);
        let lo = f64_array_opt(&[None]);
        let hi = f64_array(&[10.0]);
        let n = i32_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    // ---- Duration(Microsecond) ---------------------------------------------

    /// Two buckets over [0us, 2_000_000us).
    #[test]
    fn test_width_bucket_duration_us() {
        let v = dur_us_array(&[1_000_000, 0, -1]);
        let lo = dur_us_array(&[0, 0, 0]);
        let hi = dur_us_array(&[2_000_000, 2_000_000, 2_000_000]);
        let n = i32_array_all(3, 2);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[2, 1, 0]);
    }

    /// An empty duration range (lo == hi) yields null.
    #[test]
    fn test_width_bucket_duration_us_equal_bounds() {
        let v = dur_us_array(&[0]);
        let lo = dur_us_array(&[1]);
        let hi = dur_us_array(&[1]);
        let n = i32_array_all(1, 10);
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    // ---- Interval(YearMonth) -----------------------------------------------

    /// One bucket per month over [0, 12) months.
    #[test]
    fn test_width_bucket_interval_ym_basic() {
        let v = ym_array(&[0, 5, 11, 12, 13]);
        let lo = ym_array(&[0; 5]);
        let hi = ym_array(&[12; 5]);
        let n = i32_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 6, 12, 13, 13]);
    }

    /// Descending month range (12 down to 0).
    #[test]
    fn test_width_bucket_interval_ym_desc() {
        let v = ym_array(&[11, 12, 0, -1, 13]);
        let lo = ym_array(&[12; 5]);
        let hi = ym_array(&[0; 5]);
        let n = i32_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[2, 1, 13, 13, 0]);
    }

    // ---- Interval(MonthDayNano) --------------------------------------------

    /// Bounds differ only in months: behaves like the YearMonth case.
    #[test]
    fn test_width_bucket_interval_mdn_months_only_basic() {
        let v = mdn_array(&[(0, 0, 0), (5, 0, 0), (11, 0, 0), (12, 0, 0), (13, 0, 0)]);
        let lo = mdn_array(&[(0, 0, 0); 5]);
        let hi = mdn_array(&[(12, 0, 0); 5]);
        let n = i32_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 6, 12, 13, 13]);
    }

    /// Months-only descending range.
    #[test]
    fn test_width_bucket_interval_mdn_months_only_desc() {
        let v = mdn_array(&[(11, 0, 0), (12, 0, 0), (0, 0, 0), (-1, 0, 0), (13, 0, 0)]);
        let lo = mdn_array(&[(12, 0, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i32_array_all(5, 12);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[2, 1, 13, 13, 0]);
    }

    /// Bounds share months: bucketing is by day/nanosecond length.
    #[test]
    fn test_width_bucket_interval_mdn_day_nano_basic() {
        let v = mdn_array(&[
            (0, 0, 0),
            (0, 5, 0),
            (0, 9, 0),
            (0, 10, 0),
            (0, -1, 0),
            (0, 11, 0),
        ]);
        let lo = mdn_array(&[(0, 0, 0); 6]);
        let hi = mdn_array(&[(0, 10, 0); 6]);
        let n = i32_array_all(6, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert_eq!(out.values(), &[1, 6, 10, 11, 0, 11]);
    }

    /// Day/nanosecond descending range (10 days down to 0).
    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc() {
        let v = mdn_array(&[(0, 9, 0), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]);
        let lo = mdn_array(&[(0, 10, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i32_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[2, 1, 11, 11, 0]);
    }

    /// 9 days + 1ns is just inside the first descending bucket (10 days down to 9).
    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc_inside() {
        let v = mdn_array(&[(0, 9, 1), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]);
        let lo = mdn_array(&[(0, 10, 0); 5]);
        let hi = mdn_array(&[(0, 0, 0); 5]);
        let n = i32_array_all(5, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);

        assert_eq!(out.values(), &[1, 1, 11, 11, 0]);
    }

    /// A range that changes both months and days has no defined width -> null.
    #[test]
    fn test_width_bucket_interval_mdn_mixed_months_and_days_is_null() {
        let v = mdn_array(&[(0, 1, 0)]);
        let lo = mdn_array(&[(0, 0, 0)]);
        let hi = mdn_array(&[(1, 1, 0)]);
        let n = i32_array_all(1, 4);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
    }

    /// Identical bounds -> null.
    #[test]
    fn test_width_bucket_interval_mdn_equal_bounds_is_null() {
        let v = mdn_array(&[(0, 0, 0)]);
        let lo = mdn_array(&[(1, 2, 3)]);
        let hi = mdn_array(&[(1, 2, 3)]);
        let n = i32_array_all(1, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    /// A zero bucket count -> null.
    #[test]
    fn test_width_bucket_interval_mdn_invalid_n_is_null() {
        let v = mdn_array(&[(0, 0, 0)]);
        let lo = mdn_array(&[(0, 0, 0)]);
        let hi = mdn_array(&[(0, 10, 0)]);
        let n = Arc::new(Int32Array::from(vec![0]));
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        assert!(downcast_i32(&out).is_null(0));
    }

    /// A null interval value yields a null row; the non-null row is still bucketed.
    #[test]
    fn test_width_bucket_interval_mdn_nulls_propagate() {
        let v = Arc::new(IntervalMonthDayNanoArray::from(vec![
            None,
            Some(IntervalMonthDayNano::new(0, 5, 0)),
        ]));
        let lo = mdn_array(&[(0, 0, 0), (0, 0, 0)]);
        let hi = mdn_array(&[(0, 10, 0), (0, 10, 0)]);
        let n = i32_array_all(2, 10);

        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        let out = downcast_i32(&out);
        assert!(out.is_null(0));
        assert_eq!(out.value(1), 6);
    }

    // ---- Error paths -------------------------------------------------------

    /// Passing three arguments is reported as an error, not a panic.
    #[test]
    fn test_width_bucket_wrong_arg_count() {
        let v = f64_array(&[1.0]);
        let lo = f64_array(&[0.0]);
        let hi = f64_array(&[10.0]);
        let err = width_bucket_kern(&[v, lo, hi]).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("expects exactly 4"), "unexpected error: {msg}");
    }

    /// An uncoerced Int32 value column is rejected by the kernel with an error.
    #[test]
    fn test_width_bucket_unsupported_type() {
        let v: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let lo = f64_array(&[0.0, 0.0, 0.0]);
        let hi = f64_array(&[10.0, 10.0, 10.0]);
        let n = i32_array_all(3, 10);

        let err = width_bucket_kern(&[v, lo, hi, n]).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("unsupported data types")
                || msg.contains("Float/Decimal OR Duration OR Interval(YearMonth)"),
            "unexpected error: {msg}"
        );
    }
}