1use std::sync::Arc;
19
20use arrow::array::{
21 Array, ArrayRef, DurationMicrosecondArray, Float64Array, IntervalMonthDayNanoArray,
22 IntervalYearMonthArray,
23};
24use arrow::datatypes::DataType;
25use arrow::datatypes::DataType::{Duration, Float64, Int32, Interval};
26use arrow::datatypes::IntervalUnit::{MonthDayNano, YearMonth};
27use datafusion_common::cast::{
28 as_duration_microsecond_array, as_float64_array, as_int64_array,
29 as_interval_mdn_array, as_interval_ym_array,
30};
31use datafusion_common::types::{
32 NativeType, logical_duration_microsecond, logical_float64, logical_int64,
33 logical_interval_mdn, logical_interval_year_month,
34};
35use datafusion_common::{Result, exec_err, internal_err};
36use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
37use datafusion_expr::{
38 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
39 TypeSignatureClass,
40};
41use datafusion_functions::utils::make_scalar_function;
42
43use arrow::array::{Int32Array, Int32Builder, Int64Array};
44use arrow::datatypes::TimeUnit::Microsecond;
45use datafusion_expr::Coercion;
46use datafusion_expr::Volatility::Immutable;
47
48#[derive(Debug, PartialEq, Eq, Hash)]
49pub struct SparkWidthBucket {
50 signature: Signature,
51}
52
53impl Default for SparkWidthBucket {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl SparkWidthBucket {
60 pub fn new() -> Self {
61 let numeric = Coercion::new_implicit(
62 TypeSignatureClass::Native(logical_float64()),
63 vec![TypeSignatureClass::Numeric],
64 NativeType::Float64,
65 );
66 let duration = Coercion::new_implicit(
67 TypeSignatureClass::Native(logical_duration_microsecond()),
68 vec![TypeSignatureClass::Duration],
69 NativeType::Duration(Microsecond),
70 );
71 let interval_ym = Coercion::new_exact(TypeSignatureClass::Native(
72 logical_interval_year_month(),
73 ));
74 let interval_mdn =
75 Coercion::new_exact(TypeSignatureClass::Native(logical_interval_mdn()));
76 let bucket = Coercion::new_implicit(
77 TypeSignatureClass::Native(logical_int64()),
78 vec![TypeSignatureClass::Integer],
79 NativeType::Int64,
80 );
81 let type_signature = Signature::one_of(
82 vec![
83 TypeSignature::Coercible(vec![
84 numeric.clone(),
85 numeric.clone(),
86 numeric.clone(),
87 bucket.clone(),
88 ]),
89 TypeSignature::Coercible(vec![
90 duration.clone(),
91 duration.clone(),
92 duration.clone(),
93 bucket.clone(),
94 ]),
95 TypeSignature::Coercible(vec![
96 interval_ym.clone(),
97 interval_ym.clone(),
98 interval_ym.clone(),
99 bucket.clone(),
100 ]),
101 TypeSignature::Coercible(vec![
102 interval_mdn.clone(),
103 interval_mdn.clone(),
104 interval_mdn.clone(),
105 bucket.clone(),
106 ]),
107 ],
108 Immutable,
109 )
110 .with_parameter_names(vec!["expr", "min", "max", "num_buckets"])
111 .expect("valid parameter names");
112 Self {
113 signature: type_signature,
114 }
115 }
116}
117
118impl ScalarUDFImpl for SparkWidthBucket {
119 fn name(&self) -> &str {
120 "width_bucket"
121 }
122
123 fn signature(&self) -> &Signature {
124 &self.signature
125 }
126
127 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
128 Ok(Int32)
129 }
130
131 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
132 make_scalar_function(width_bucket_kern, vec![])(&args.args)
133 }
134
135 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
136 if input.len() == 1 {
137 let value = &input[0];
138 Ok(value.sort_properties)
139 } else {
140 Ok(SortProperties::default())
141 }
142 }
143}
144
145fn width_bucket_kern(args: &[ArrayRef]) -> Result<ArrayRef> {
146 let [v, minv, maxv, nb] = args else {
147 return exec_err!(
148 "width_bucket expects exactly 4 argument, got {}",
149 args.len()
150 );
151 };
152
153 match v.data_type() {
154 Float64 => {
155 let v = as_float64_array(v)?;
156 let min = as_float64_array(minv)?;
157 let max = as_float64_array(maxv)?;
158 let n_bucket = as_int64_array(nb)?;
159 Ok(Arc::new(width_bucket_float64(v, min, max, n_bucket)))
160 }
161 Duration(Microsecond) => {
162 let v = as_duration_microsecond_array(v)?;
163 let min = as_duration_microsecond_array(minv)?;
164 let max = as_duration_microsecond_array(maxv)?;
165 let n_bucket = as_int64_array(nb)?;
166 Ok(Arc::new(width_bucket_i64_as_float(v, min, max, n_bucket)))
167 }
168 Interval(YearMonth) => {
169 let v = as_interval_ym_array(v)?;
170 let min = as_interval_ym_array(minv)?;
171 let max = as_interval_ym_array(maxv)?;
172 let n_bucket = as_int64_array(nb)?;
173 Ok(Arc::new(width_bucket_i32_as_float(v, min, max, n_bucket)))
174 }
175 Interval(MonthDayNano) => {
176 let v = as_interval_mdn_array(v)?;
177 let min = as_interval_mdn_array(minv)?;
178 let max = as_interval_mdn_array(maxv)?;
179 let n_bucket = as_int64_array(nb)?;
180 Ok(Arc::new(width_bucket_interval_mdn_exact(
181 v, min, max, n_bucket,
182 )))
183 }
184
185 other => internal_err!(
186 "width_bucket received unexpected data types: {:?}, {:?}, {:?}, {:?}",
187 other,
188 minv.data_type(),
189 maxv.data_type(),
190 nb.data_type()
191 ),
192 }
193}
194
// Generates a `width_bucket` kernel for one primitive array type.
//
// Macro arguments:
// * `$name`      - name of the generated `pub(crate)` function,
// * `$arr_ty`    - array type of the value / lower-bound / upper-bound inputs,
// * `$to_f64`    - closure `(&$arr_ty, usize) -> f64` reading element `i`,
// * `$check_nan` - when `true`, rows with a non-finite value or bound are null.
//
// Per-row behaviour of the generated function:
// * null if any input is null, if `n_bucket <= 0`, if the bounds are equal or
//   unordered, or if the computed bucket width is zero / non-finite;
// * `0` when the value lies before the range, `n_bucket + 1` when it lies at
//   or past the far bound (range direction follows `min` vs `max`);
// * otherwise `floor((x - min) / width) + 1`, clamped to `1..=n_bucket + 1`.
//
// All index arithmetic is done in `i64` with saturating operations and only
// narrowed to the `Int32` result at the very end, saturating at `i32::MAX`.
// The previous `(buckets + 1) as i32` / `... as i32 + 1` forms overflowed
// (panic in debug builds, wrap-around in release) for very large bucket
// counts; results for counts that fit in `i32` are unchanged.
macro_rules! width_bucket_kernel_impl {
    ($name:ident, $arr_ty:ty, $to_f64:expr, $check_nan:expr) => {
        pub(crate) fn $name(
            v: &$arr_ty,
            min: &$arr_ty,
            max: &$arr_ty,
            n_bucket: &Int64Array,
        ) -> Int32Array {
            let len = v.len();
            let mut b = Int32Builder::with_capacity(len);

            for i in 0..len {
                if v.is_null(i) || min.is_null(i) || max.is_null(i) || n_bucket.is_null(i)
                {
                    b.append_null();
                    continue;
                }
                let x = ($to_f64)(v, i);
                let l = ($to_f64)(min, i);
                let h = ($to_f64)(max, i);
                let buckets = n_bucket.value(i);

                if buckets <= 0 {
                    b.append_null();
                    continue;
                }
                // Index of the overflow bucket, kept in i64 so that
                // `i64::MAX` cannot overflow and nothing is truncated.
                let next_bucket: i64 = buckets.saturating_add(1);

                if $check_nan && (!x.is_finite() || !l.is_finite() || !h.is_finite()) {
                    b.append_null();
                    continue;
                }

                // Equal or unordered (NaN) bounds do not define a range.
                let asc = match l.partial_cmp(&h) {
                    Some(std::cmp::Ordering::Less) => true,
                    Some(std::cmp::Ordering::Greater) => false,
                    Some(std::cmp::Ordering::Equal) | None => {
                        b.append_null();
                        continue;
                    }
                };

                // The near bound is inclusive, the far bound belongs to the
                // overflow bucket; the comparisons flip for descending ranges.
                let (before_range, past_range) = if asc {
                    (x < l, x >= h)
                } else {
                    (x > l, x <= h)
                };
                if before_range {
                    b.append_value(0);
                    continue;
                }
                if past_range {
                    b.append_value(i32::try_from(next_bucket).unwrap_or(i32::MAX));
                    continue;
                }

                let width = (h - l) / (buckets as f64);
                if width == 0.0 || !width.is_finite() {
                    b.append_null();
                    continue;
                }

                // Float-to-int `as` casts saturate, so `raw` is always a valid
                // i64; the saturating add keeps the `+ 1` from overflowing.
                let raw = ((x - l) / width).floor() as i64;
                let bucket = raw.saturating_add(1).clamp(1, next_bucket);

                b.append_value(i32::try_from(bucket).unwrap_or(i32::MAX));
            }

            b.finish()
        }
    };
}
282
283width_bucket_kernel_impl!(
284 width_bucket_float64,
285 Float64Array,
286 |arr: &Float64Array, i: usize| arr.value(i),
287 true
288);
289
290width_bucket_kernel_impl!(
291 width_bucket_i64_as_float,
292 DurationMicrosecondArray,
293 |arr: &DurationMicrosecondArray, i: usize| arr.value(i) as f64,
294 false
295);
296
297width_bucket_kernel_impl!(
298 width_bucket_i32_as_float,
299 IntervalYearMonthArray,
300 |arr: &IntervalYearMonthArray, i: usize| arr.value(i) as f64,
301 false
302);
303const NS_PER_DAY_I128: i128 = 86_400_000_000_000;
304pub(crate) fn width_bucket_interval_mdn_exact(
305 v: &IntervalMonthDayNanoArray,
306 lo: &IntervalMonthDayNanoArray,
307 hi: &IntervalMonthDayNanoArray,
308 n: &Int64Array,
309) -> Int32Array {
310 let len = v.len();
311 let mut b = Int32Builder::with_capacity(len);
312
313 for i in 0..len {
314 if v.is_null(i) || lo.is_null(i) || hi.is_null(i) || n.is_null(i) {
315 b.append_null();
316 continue;
317 }
318 let buckets = n.value(i);
319 if buckets <= 0 {
320 b.append_null();
321 continue;
322 }
323 let next_bucket = (buckets + 1) as i32;
324
325 let x = v.value(i);
326 let l = lo.value(i);
327 let h = hi.value(i);
328
329 let asc = (l.months, l.days, l.nanoseconds) < (h.months, h.days, h.nanoseconds);
332 if (l.months, l.days, l.nanoseconds) == (h.months, h.days, h.nanoseconds) {
333 b.append_null();
334 continue;
335 }
336
337 if l.days == h.days && l.nanoseconds == h.nanoseconds && l.months != h.months {
339 let x_m = x.months as f64;
340 let l_m = l.months as f64;
341 let h_m = h.months as f64;
342
343 if asc {
344 if x_m < l_m {
345 b.append_value(0);
346 continue;
347 }
348 if x_m >= h_m {
349 b.append_value(next_bucket);
350 continue;
351 }
352 } else {
353 if x_m > l_m {
354 b.append_value(0);
355 continue;
356 }
357 if x_m <= h_m {
358 b.append_value(next_bucket);
359 continue;
360 }
361 }
362
363 let width = (h_m - l_m) / (buckets as f64);
364 if width == 0.0 || !width.is_finite() {
365 b.append_null();
366 continue;
367 }
368
369 let mut bucket = ((x_m - l_m) / width).floor() as i32 + 1;
370 if bucket < 1 {
371 bucket = 1;
372 }
373 if bucket > next_bucket {
374 bucket = next_bucket;
375 }
376 b.append_value(bucket);
377 continue;
378 }
379
380 if l.months == h.months {
382 let base_days = l.days as i128;
383 let base_ns = l.nanoseconds as i128;
384
385 let xf = (x.days as i128 - base_days) * NS_PER_DAY_I128
386 + (x.nanoseconds as i128 - base_ns);
387 let hf = (h.days as i128 - base_days) * NS_PER_DAY_I128
388 + (h.nanoseconds as i128 - base_ns);
389
390 let x_f = xf as f64;
391 let l_f = 0.0;
392 let h_f = hf as f64;
393
394 if asc {
395 if x_f < l_f {
396 b.append_value(0);
397 continue;
398 }
399 if x_f >= h_f {
400 b.append_value(next_bucket);
401 continue;
402 }
403 } else {
404 if x_f > l_f {
405 b.append_value(0);
406 continue;
407 }
408 if x_f <= h_f {
409 b.append_value(next_bucket);
410 continue;
411 }
412 }
413
414 let width = (h_f - l_f) / (buckets as f64);
415 if width == 0.0 || !width.is_finite() {
416 b.append_null();
417 continue;
418 }
419
420 let mut bucket = ((x_f - l_f) / width).floor() as i32 + 1;
421 if bucket < 1 {
422 bucket = 1;
423 }
424 if bucket > next_bucket {
425 bucket = next_bucket;
426 }
427 b.append_value(bucket);
428 continue;
429 }
430
431 b.append_null();
432 }
433
434 b.finish()
435}
436
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::array::{
        ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array, Int64Array,
        IntervalYearMonthArray,
    };
    use arrow::datatypes::IntervalMonthDayNano;

    // ---------------------------------------------------------------------
    // Fixture helpers
    // ---------------------------------------------------------------------

    /// Int64 array holding `val` repeated `len` times (used for bucket counts).
    fn i64_array_all(len: usize, val: i64) -> Arc<Int64Array> {
        let repeated = vec![val; len];
        Arc::new(Int64Array::from(repeated))
    }

    /// Non-null Float64 array.
    fn f64_array(vals: &[f64]) -> Arc<Float64Array> {
        let owned = vals.to_vec();
        Arc::new(Float64Array::from(owned))
    }

    /// Float64 array that may contain nulls.
    fn f64_array_opt(vals: &[Option<f64>]) -> Arc<Float64Array> {
        let owned = vals.to_vec();
        Arc::new(Float64Array::from(owned))
    }

    /// Non-null Duration(Microsecond) array.
    fn dur_us_array(vals: &[i64]) -> Arc<DurationMicrosecondArray> {
        let owned = vals.to_vec();
        Arc::new(DurationMicrosecondArray::from(owned))
    }

    /// Non-null Interval(YearMonth) array.
    fn ym_array(vals: &[i32]) -> Arc<IntervalYearMonthArray> {
        let owned = vals.to_vec();
        Arc::new(IntervalYearMonthArray::from(owned))
    }

    /// Views a kernel result as the `Int32Array` it is expected to be.
    fn downcast_i32(arr: &ArrayRef) -> &Int32Array {
        arr.as_any().downcast_ref::<Int32Array>().unwrap()
    }

    /// Non-null Interval(MonthDayNano) array from `(months, days, nanos)` tuples.
    fn mdn_array(vals: &[(i32, i32, i64)]) -> Arc<IntervalMonthDayNanoArray> {
        let mut data: Vec<IntervalMonthDayNano> = Vec::with_capacity(vals.len());
        for &(months, days, nanos) in vals {
            data.push(IntervalMonthDayNano::new(months, days, nanos));
        }
        Arc::new(IntervalMonthDayNanoArray::from(data))
    }

    /// Runs the kernel, expecting success, and returns the Int32 result.
    fn run(args: &[ArrayRef]) -> Int32Array {
        let result = width_bucket_kern(args).unwrap();
        downcast_i32(&result).clone()
    }

    // ---------------------------------------------------------------------
    // Float64
    // ---------------------------------------------------------------------

    #[test]
    fn test_width_bucket_f64_basic() {
        let values = f64_array(&[0.5, 1.0, 9.9, -1.0, 10.0]);
        let lower = f64_array(&[0.0, 0.0, 0.0, 0.0, 0.0]);
        let upper = f64_array(&[10.0, 10.0, 10.0, 10.0, 10.0]);
        let buckets = i64_array_all(5, 10);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[1, 2, 10, 0, 11]);
    }

    #[test]
    fn test_width_bucket_f64_descending_range() {
        let values = f64_array(&[9.9, 10.0, 0.0, -0.1, 10.1]);
        let lower = f64_array(&[10.0; 5]);
        let upper = f64_array(&[0.0; 5]);
        let buckets = i64_array_all(5, 10);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[1, 1, 11, 11, 0]);
    }

    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_asc() {
        let values = f64_array(&[0.0, 9.999999999, 10.0]);
        let lower = f64_array(&[0.0; 3]);
        let upper = f64_array(&[10.0; 3]);
        let buckets = i64_array_all(3, 10);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[1, 10, 11]);
    }

    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_desc() {
        let values = f64_array(&[10.0, 0.0, -0.000001]);
        let lower = f64_array(&[10.0; 3]);
        let upper = f64_array(&[0.0; 3]);
        let buckets = i64_array_all(3, 10);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[1, 11, 11]);
    }

    #[test]
    fn test_width_bucket_f64_edge_cases() {
        // Non-positive bucket counts give null; a valid count still works.
        {
            let values = f64_array(&[1.0, 5.0, 9.0]);
            let lower = f64_array(&[0.0, 0.0, 0.0]);
            let upper = f64_array(&[10.0, 10.0, 10.0]);
            let buckets = Arc::new(Int64Array::from(vec![0, -1, 10]));
            let got = run(&[values, lower, upper, buckets]);
            assert!(got.is_null(0));
            assert!(got.is_null(1));
            assert_eq!(got.value(2), 10);
        }

        // Equal bounds give null.
        {
            let values = f64_array(&[1.0]);
            let lower = f64_array(&[5.0]);
            let upper = f64_array(&[5.0]);
            let buckets = i64_array_all(1, 10);
            let got = run(&[values, lower, upper, buckets]);
            assert!(got.is_null(0));
        }

        // A NaN value gives null.
        {
            let values = f64_array_opt(&[Some(f64::NAN)]);
            let lower = f64_array(&[0.0]);
            let upper = f64_array(&[10.0]);
            let buckets = i64_array_all(1, 10);
            let got = run(&[values, lower, upper, buckets]);
            assert!(got.is_null(0));
        }
    }

    #[test]
    fn test_width_bucket_f64_nulls_propagate() {
        // Null in the value column.
        {
            let values = f64_array_opt(&[None, Some(1.0), Some(2.0), Some(3.0)]);
            let lower = f64_array(&[0.0; 4]);
            let upper = f64_array(&[10.0; 4]);
            let buckets = i64_array_all(4, 10);

            let got = run(&[values, lower, upper, buckets]);
            assert!(got.is_null(0));
            assert_eq!(got.value(1), 2);
            assert_eq!(got.value(2), 3);
            assert_eq!(got.value(3), 4);
        }

        // Null in the lower-bound column.
        {
            let values = f64_array(&[1.0]);
            let lower = f64_array_opt(&[None]);
            let upper = f64_array(&[10.0]);
            let buckets = i64_array_all(1, 10);
            let got = run(&[values, lower, upper, buckets]);
            assert!(got.is_null(0));
        }
    }

    // ---------------------------------------------------------------------
    // Duration(Microsecond)
    // ---------------------------------------------------------------------

    #[test]
    fn test_width_bucket_duration_us() {
        let values = dur_us_array(&[1_000_000, 0, -1]);
        let lower = dur_us_array(&[0, 0, 0]);
        let upper = dur_us_array(&[2_000_000, 2_000_000, 2_000_000]);
        let buckets = i64_array_all(3, 2);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[2, 1, 0]);
    }

    #[test]
    fn test_width_bucket_duration_us_equal_bounds() {
        let values = dur_us_array(&[0]);
        let lower = dur_us_array(&[1]);
        let upper = dur_us_array(&[1]);
        let buckets = i64_array_all(1, 10);

        let got = run(&[values, lower, upper, buckets]);
        assert!(got.is_null(0));
    }

    // ---------------------------------------------------------------------
    // Interval(YearMonth)
    // ---------------------------------------------------------------------

    #[test]
    fn test_width_bucket_interval_ym_basic() {
        let values = ym_array(&[0, 5, 11, 12, 13]);
        let lower = ym_array(&[0; 5]);
        let upper = ym_array(&[12; 5]);
        let buckets = i64_array_all(5, 12);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[1, 6, 12, 13, 13]);
    }

    #[test]
    fn test_width_bucket_interval_ym_desc() {
        let values = ym_array(&[11, 12, 0, -1, 13]);
        let lower = ym_array(&[12; 5]);
        let upper = ym_array(&[0; 5]);
        let buckets = i64_array_all(5, 12);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[2, 1, 13, 13, 0]);
    }

    // ---------------------------------------------------------------------
    // Interval(MonthDayNano)
    // ---------------------------------------------------------------------

    #[test]
    fn test_width_bucket_interval_mdn_months_only_basic() {
        let values =
            mdn_array(&[(0, 0, 0), (5, 0, 0), (11, 0, 0), (12, 0, 0), (13, 0, 0)]);
        let lower = mdn_array(&[(0, 0, 0); 5]);
        let upper = mdn_array(&[(12, 0, 0); 5]);
        let buckets = i64_array_all(5, 12);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[1, 6, 12, 13, 13]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_months_only_desc() {
        let values =
            mdn_array(&[(11, 0, 0), (12, 0, 0), (0, 0, 0), (-1, 0, 0), (13, 0, 0)]);
        let lower = mdn_array(&[(12, 0, 0); 5]);
        let upper = mdn_array(&[(0, 0, 0); 5]);
        let buckets = i64_array_all(5, 12);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[2, 1, 13, 13, 0]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_day_nano_basic() {
        let values = mdn_array(&[
            (0, 0, 0),
            (0, 5, 0),
            (0, 9, 0),
            (0, 10, 0),
            (0, -1, 0),
            (0, 11, 0),
        ]);
        let lower = mdn_array(&[(0, 0, 0); 6]);
        let upper = mdn_array(&[(0, 10, 0); 6]);
        let buckets = i64_array_all(6, 10);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[1, 6, 10, 11, 0, 11]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc() {
        let values =
            mdn_array(&[(0, 9, 0), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]);
        let lower = mdn_array(&[(0, 10, 0); 5]);
        let upper = mdn_array(&[(0, 0, 0); 5]);
        let buckets = i64_array_all(5, 10);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[2, 1, 11, 11, 0]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc_inside() {
        // One nanosecond past 9 days keeps the first row inside bucket 1.
        let values =
            mdn_array(&[(0, 9, 1), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]);
        let lower = mdn_array(&[(0, 10, 0); 5]);
        let upper = mdn_array(&[(0, 0, 0); 5]);
        let buckets = i64_array_all(5, 10);

        let got = run(&[values, lower, upper, buckets]);
        assert_eq!(got.values(), &[1, 1, 11, 11, 0]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_mixed_months_and_days_is_null() {
        let values = mdn_array(&[(0, 1, 0)]);
        let lower = mdn_array(&[(0, 0, 0)]);
        let upper = mdn_array(&[(1, 1, 0)]);
        let buckets = i64_array_all(1, 4);

        let got = run(&[values, lower, upper, buckets]);
        assert!(got.is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_equal_bounds_is_null() {
        let values = mdn_array(&[(0, 0, 0)]);
        let lower = mdn_array(&[(1, 2, 3)]);
        let upper = mdn_array(&[(1, 2, 3)]);
        let buckets = i64_array_all(1, 10);

        let got = run(&[values, lower, upper, buckets]);
        assert!(got.is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_invalid_n_is_null() {
        let values = mdn_array(&[(0, 0, 0)]);
        let lower = mdn_array(&[(0, 0, 0)]);
        let upper = mdn_array(&[(0, 10, 0)]);
        let buckets = Arc::new(Int64Array::from(vec![0]));

        let got = run(&[values, lower, upper, buckets]);
        assert!(got.is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_nulls_propagate() {
        let values = Arc::new(IntervalMonthDayNanoArray::from(vec![
            None,
            Some(IntervalMonthDayNano::new(0, 5, 0)),
        ]));
        let lower = mdn_array(&[(0, 0, 0), (0, 0, 0)]);
        let upper = mdn_array(&[(0, 10, 0), (0, 10, 0)]);
        let buckets = i64_array_all(2, 10);

        let got = run(&[values, lower, upper, buckets]);
        assert!(got.is_null(0));
        assert_eq!(got.value(1), 6);
    }

    // ---------------------------------------------------------------------
    // Error paths
    // ---------------------------------------------------------------------

    #[test]
    fn test_width_bucket_wrong_arg_count() {
        let values = f64_array(&[1.0]);
        let lower = f64_array(&[0.0]);
        let upper = f64_array(&[10.0]);

        let failure = width_bucket_kern(&[values, lower, upper]).unwrap_err();
        let msg = format!("{failure}");
        assert!(msg.contains("expects exactly 4"), "unexpected error: {msg}");
    }

    #[test]
    fn test_width_bucket_unsupported_type() {
        let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let lower = f64_array(&[0.0, 0.0, 0.0]);
        let upper = f64_array(&[10.0, 10.0, 10.0]);
        let buckets = i64_array_all(3, 10);

        let failure = width_bucket_kern(&[values, lower, upper, buckets]).unwrap_err();
        let msg = format!("{failure}");
        assert!(
            msg.contains("width_bucket received unexpected data types"),
            "unexpected error: {msg}"
        );
    }
}