1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 Array, ArrayRef, DurationMicrosecondArray, Float64Array, IntervalMonthDayNanoArray,
23 IntervalYearMonthArray,
24};
25use arrow::datatypes::DataType;
26use arrow::datatypes::DataType::{Duration, Float64, Int32, Interval};
27use arrow::datatypes::IntervalUnit::{MonthDayNano, YearMonth};
28use datafusion_common::cast::{
29 as_duration_microsecond_array, as_float64_array, as_int32_array,
30 as_interval_mdn_array, as_interval_ym_array,
31};
32use datafusion_common::types::{
33 NativeType, logical_duration_microsecond, logical_float64, logical_int32,
34 logical_interval_mdn, logical_interval_year_month,
35};
36use datafusion_common::{Result, exec_err, internal_err};
37use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
38use datafusion_expr::{
39 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
40 TypeSignatureClass,
41};
42use datafusion_functions::utils::make_scalar_function;
43
44use arrow::array::{Int32Array, Int32Builder};
45use arrow::datatypes::TimeUnit::Microsecond;
46use datafusion_expr::Coercion;
47use datafusion_expr::Volatility::Immutable;
48
49#[derive(Debug, PartialEq, Eq, Hash)]
50pub struct SparkWidthBucket {
51 signature: Signature,
52}
53
54impl Default for SparkWidthBucket {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl SparkWidthBucket {
61 pub fn new() -> Self {
62 let numeric = Coercion::new_implicit(
63 TypeSignatureClass::Native(logical_float64()),
64 vec![TypeSignatureClass::Numeric],
65 NativeType::Float64,
66 );
67 let duration = Coercion::new_implicit(
68 TypeSignatureClass::Native(logical_duration_microsecond()),
69 vec![TypeSignatureClass::Duration],
70 NativeType::Duration(Microsecond),
71 );
72 let interval_ym = Coercion::new_exact(TypeSignatureClass::Native(
73 logical_interval_year_month(),
74 ));
75 let interval_mdn =
76 Coercion::new_exact(TypeSignatureClass::Native(logical_interval_mdn()));
77 let bucket = Coercion::new_implicit(
78 TypeSignatureClass::Native(logical_int32()),
79 vec![TypeSignatureClass::Integer],
80 NativeType::Int32,
81 );
82 let type_signature = Signature::one_of(
83 vec![
84 TypeSignature::Coercible(vec![
85 numeric.clone(),
86 numeric.clone(),
87 numeric.clone(),
88 bucket.clone(),
89 ]),
90 TypeSignature::Coercible(vec![
91 duration.clone(),
92 duration.clone(),
93 duration.clone(),
94 bucket.clone(),
95 ]),
96 TypeSignature::Coercible(vec![
97 interval_ym.clone(),
98 interval_ym.clone(),
99 interval_ym.clone(),
100 bucket.clone(),
101 ]),
102 TypeSignature::Coercible(vec![
103 interval_mdn.clone(),
104 interval_mdn.clone(),
105 interval_mdn.clone(),
106 bucket.clone(),
107 ]),
108 ],
109 Immutable,
110 )
111 .with_parameter_names(vec!["expr", "min", "max", "num_buckets"])
112 .expect("valid parameter names");
113 Self {
114 signature: type_signature,
115 }
116 }
117}
118
119impl ScalarUDFImpl for SparkWidthBucket {
120 fn as_any(&self) -> &dyn Any {
121 self
122 }
123
124 fn name(&self) -> &str {
125 "width_bucket"
126 }
127
128 fn signature(&self) -> &Signature {
129 &self.signature
130 }
131
132 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
133 Ok(Int32)
134 }
135
136 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
137 make_scalar_function(width_bucket_kern, vec![])(&args.args)
138 }
139
140 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
141 if input.len() == 1 {
142 let value = &input[0];
143 Ok(value.sort_properties)
144 } else {
145 Ok(SortProperties::default())
146 }
147 }
148}
149
150fn width_bucket_kern(args: &[ArrayRef]) -> Result<ArrayRef> {
151 let [v, minv, maxv, nb] = args else {
152 return exec_err!(
153 "width_bucket expects exactly 4 argument, got {}",
154 args.len()
155 );
156 };
157
158 match v.data_type() {
159 Float64 => {
160 let v = as_float64_array(v)?;
161 let min = as_float64_array(minv)?;
162 let max = as_float64_array(maxv)?;
163 let n_bucket = as_int32_array(nb)?;
164 Ok(Arc::new(width_bucket_float64(v, min, max, n_bucket)))
165 }
166 Duration(Microsecond) => {
167 let v = as_duration_microsecond_array(v)?;
168 let min = as_duration_microsecond_array(minv)?;
169 let max = as_duration_microsecond_array(maxv)?;
170 let n_bucket = as_int32_array(nb)?;
171 Ok(Arc::new(width_bucket_i64_as_float(v, min, max, n_bucket)))
172 }
173 Interval(YearMonth) => {
174 let v = as_interval_ym_array(v)?;
175 let min = as_interval_ym_array(minv)?;
176 let max = as_interval_ym_array(maxv)?;
177 let n_bucket = as_int32_array(nb)?;
178 Ok(Arc::new(width_bucket_i32_as_float(v, min, max, n_bucket)))
179 }
180 Interval(MonthDayNano) => {
181 let v = as_interval_mdn_array(v)?;
182 let min = as_interval_mdn_array(minv)?;
183 let max = as_interval_mdn_array(maxv)?;
184 let n_bucket = as_int32_array(nb)?;
185 Ok(Arc::new(width_bucket_interval_mdn_exact(
186 v, min, max, n_bucket,
187 )))
188 }
189
190 other => internal_err!(
191 "width_bucket received unexpected data types: {:?}, {:?}, {:?}, {:?}",
192 other,
193 minv.data_type(),
194 maxv.data_type(),
195 nb.data_type()
196 ),
197 }
198}
199
/// Generates a `pub(crate)` row-wise `width_bucket` kernel named `$name`.
///
/// Macro parameters:
/// * `$arr_ty` - concrete Arrow array type holding the value and both bounds.
/// * `$to_f64` - `(array, index) -> f64` accessor used to read one element.
/// * `$check_nan` - when `true`, a non-finite value or bound yields NULL.
///
/// Per-row result of the generated function:
/// * NULL when any input is NULL, when `num_buckets <= 0`, when
///   `num_buckets == i32::MAX` (the overflow bucket `num_buckets + 1` would
///   not fit in an `i32`), when the bounds are equal or not comparable, or
///   when the bucket width is zero or not finite.
/// * `0` when the value lies before `min` in the direction of the range.
/// * `num_buckets + 1` when the value lies at or beyond `max`.
/// * otherwise the 1-based index of the equal-width bucket holding the value.
macro_rules! width_bucket_kernel_impl {
    ($name:ident, $arr_ty:ty, $to_f64:expr, $check_nan:expr) => {
        pub(crate) fn $name(
            v: &$arr_ty,
            min: &$arr_ty,
            max: &$arr_ty,
            n_bucket: &Int32Array,
        ) -> Int32Array {
            let len = v.len();
            let mut b = Int32Builder::with_capacity(len);

            for i in 0..len {
                if v.is_null(i) || min.is_null(i) || max.is_null(i) || n_bucket.is_null(i)
                {
                    b.append_null();
                    continue;
                }
                let x = ($to_f64)(v, i);
                let l = ($to_f64)(min, i);
                let h = ($to_f64)(max, i);
                let buckets = n_bucket.value(i);

                // `buckets + 1` is a valid result (the overflow bucket), so the
                // count must leave room for it: i32::MAX would overflow.
                if buckets <= 0 || buckets == i32::MAX {
                    b.append_null();
                    continue;
                }
                let overflow_bucket = buckets + 1;

                if $check_nan {
                    if !x.is_finite() || !l.is_finite() || !h.is_finite() {
                        b.append_null();
                        continue;
                    }
                }

                // Direction of the range; equal or incomparable bounds are NULL.
                let asc = match l.partial_cmp(&h) {
                    Some(std::cmp::Ordering::Less) => true,
                    Some(std::cmp::Ordering::Greater) => false,
                    Some(std::cmp::Ordering::Equal) | None => {
                        b.append_null();
                        continue;
                    }
                };

                // The lower bound is inclusive, the upper bound exclusive.
                let (before_range, beyond_range) = if asc {
                    (x < l, x >= h)
                } else {
                    (x > l, x <= h)
                };
                if before_range {
                    b.append_value(0);
                    continue;
                }
                if beyond_range {
                    b.append_value(overflow_bucket);
                    continue;
                }

                let width = (h - l) / (buckets as f64);
                if width == 0.0 || !width.is_finite() {
                    b.append_null();
                    continue;
                }

                // `as i32` saturates; `saturating_add` keeps the `+ 1` safe
                // even if rounding pushes the quotient to the top of the range.
                let bucket = (((x - l) / width).floor() as i32)
                    .saturating_add(1)
                    .clamp(1, overflow_bucket);

                b.append_value(bucket);
            }

            b.finish()
        }
    };
}
286
287width_bucket_kernel_impl!(
288 width_bucket_float64,
289 Float64Array,
290 |arr: &Float64Array, i: usize| arr.value(i),
291 true
292);
293
294width_bucket_kernel_impl!(
295 width_bucket_i64_as_float,
296 DurationMicrosecondArray,
297 |arr: &DurationMicrosecondArray, i: usize| arr.value(i) as f64,
298 false
299);
300
301width_bucket_kernel_impl!(
302 width_bucket_i32_as_float,
303 IntervalYearMonthArray,
304 |arr: &IntervalYearMonthArray, i: usize| arr.value(i) as f64,
305 false
306);
307const NS_PER_DAY_I128: i128 = 86_400_000_000_000;
308pub(crate) fn width_bucket_interval_mdn_exact(
309 v: &IntervalMonthDayNanoArray,
310 lo: &IntervalMonthDayNanoArray,
311 hi: &IntervalMonthDayNanoArray,
312 n: &Int32Array,
313) -> Int32Array {
314 let len = v.len();
315 let mut b = Int32Builder::with_capacity(len);
316
317 for i in 0..len {
318 if v.is_null(i) || lo.is_null(i) || hi.is_null(i) || n.is_null(i) {
319 b.append_null();
320 continue;
321 }
322 let buckets = n.value(i);
323 if buckets <= 0 {
324 b.append_null();
325 continue;
326 }
327
328 let x = v.value(i);
329 let l = lo.value(i);
330 let h = hi.value(i);
331
332 let asc = (l.months, l.days, l.nanoseconds) < (h.months, h.days, h.nanoseconds);
335 if (l.months, l.days, l.nanoseconds) == (h.months, h.days, h.nanoseconds) {
336 b.append_null();
337 continue;
338 }
339
340 if l.days == h.days && l.nanoseconds == h.nanoseconds && l.months != h.months {
342 let x_m = x.months as f64;
343 let l_m = l.months as f64;
344 let h_m = h.months as f64;
345
346 if asc {
347 if x_m < l_m {
348 b.append_value(0);
349 continue;
350 }
351 if x_m >= h_m {
352 b.append_value(buckets + 1);
353 continue;
354 }
355 } else {
356 if x_m > l_m {
357 b.append_value(0);
358 continue;
359 }
360 if x_m <= h_m {
361 b.append_value(buckets + 1);
362 continue;
363 }
364 }
365
366 let width = (h_m - l_m) / (buckets as f64);
367 if width == 0.0 || !width.is_finite() {
368 b.append_null();
369 continue;
370 }
371
372 let mut bucket = ((x_m - l_m) / width).floor() as i32 + 1;
373 if bucket < 1 {
374 bucket = 1;
375 }
376 if bucket > buckets + 1 {
377 bucket = buckets + 1;
378 }
379 b.append_value(bucket);
380 continue;
381 }
382
383 if l.months == h.months {
385 let base_days = l.days as i128;
386 let base_ns = l.nanoseconds as i128;
387
388 let xf = (x.days as i128 - base_days) * NS_PER_DAY_I128
389 + (x.nanoseconds as i128 - base_ns);
390 let hf = (h.days as i128 - base_days) * NS_PER_DAY_I128
391 + (h.nanoseconds as i128 - base_ns);
392
393 let x_f = xf as f64;
394 let l_f = 0.0;
395 let h_f = hf as f64;
396
397 if asc {
398 if x_f < l_f {
399 b.append_value(0);
400 continue;
401 }
402 if x_f >= h_f {
403 b.append_value(buckets + 1);
404 continue;
405 }
406 } else {
407 if x_f > l_f {
408 b.append_value(0);
409 continue;
410 }
411 if x_f <= h_f {
412 b.append_value(buckets + 1);
413 continue;
414 }
415 }
416
417 let width = (h_f - l_f) / (buckets as f64);
418 if width == 0.0 || !width.is_finite() {
419 b.append_null();
420 continue;
421 }
422
423 let mut bucket = ((x_f - l_f) / width).floor() as i32 + 1;
424 if bucket < 1 {
425 bucket = 1;
426 }
427 if bucket > buckets + 1 {
428 bucket = buckets + 1;
429 }
430 b.append_value(bucket);
431 continue;
432 }
433
434 b.append_null();
435 }
436
437 b.finish()
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow::array::{
        ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array,
        IntervalYearMonthArray,
    };
    use arrow::datatypes::IntervalMonthDayNano;

    // ---- fixture builders -------------------------------------------------

    /// Bucket-count column holding `val` repeated `len` times.
    fn counts(len: usize, val: i32) -> ArrayRef {
        Arc::new(Int32Array::from(vec![val; len]))
    }

    fn floats(vals: &[f64]) -> ArrayRef {
        Arc::new(Float64Array::from(vals.to_vec()))
    }

    fn floats_opt(vals: &[Option<f64>]) -> ArrayRef {
        Arc::new(Float64Array::from(vals.to_vec()))
    }

    fn durations_us(vals: &[i64]) -> ArrayRef {
        Arc::new(DurationMicrosecondArray::from(vals.to_vec()))
    }

    fn year_months(vals: &[i32]) -> ArrayRef {
        Arc::new(IntervalYearMonthArray::from(vals.to_vec()))
    }

    /// Builds a MonthDayNano column from `(months, days, nanoseconds)` triples.
    fn month_day_nanos(vals: &[(i32, i32, i64)]) -> ArrayRef {
        let data: Vec<IntervalMonthDayNano> = vals
            .iter()
            .map(|&(m, d, ns)| IntervalMonthDayNano::new(m, d, ns))
            .collect();
        Arc::new(IntervalMonthDayNanoArray::from(data))
    }

    /// Runs the kernel on four columns and returns the resulting Int32 array.
    fn run(v: ArrayRef, lo: ArrayRef, hi: ArrayRef, n: ArrayRef) -> Int32Array {
        let out = width_bucket_kern(&[v, lo, hi, n]).unwrap();
        out.as_any().downcast_ref::<Int32Array>().unwrap().clone()
    }

    // ---- Float64 ----------------------------------------------------------

    #[test]
    fn test_width_bucket_f64_basic() {
        let got = run(
            floats(&[0.5, 1.0, 9.9, -1.0, 10.0]),
            floats(&[0.0, 0.0, 0.0, 0.0, 0.0]),
            floats(&[10.0, 10.0, 10.0, 10.0, 10.0]),
            counts(5, 10),
        );
        assert_eq!(got.values(), &[1, 2, 10, 0, 11]);
    }

    #[test]
    fn test_width_bucket_f64_descending_range() {
        let got = run(
            floats(&[9.9, 10.0, 0.0, -0.1, 10.1]),
            floats(&[10.0; 5]),
            floats(&[0.0; 5]),
            counts(5, 10),
        );
        assert_eq!(got.values(), &[1, 1, 11, 11, 0]);
    }

    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_asc() {
        let got = run(
            floats(&[0.0, 9.999999999, 10.0]),
            floats(&[0.0; 3]),
            floats(&[10.0; 3]),
            counts(3, 10),
        );
        assert_eq!(got.values(), &[1, 10, 11]);
    }

    #[test]
    fn test_width_bucket_f64_bounds_inclusive_exclusive_desc() {
        let got = run(
            floats(&[10.0, 0.0, -0.000001]),
            floats(&[10.0; 3]),
            floats(&[0.0; 3]),
            counts(3, 10),
        );
        assert_eq!(got.values(), &[1, 11, 11]);
    }

    #[test]
    fn test_width_bucket_f64_edge_cases() {
        // Non-positive bucket counts produce NULL; a valid count still works.
        let bucket_counts: ArrayRef = Arc::new(Int32Array::from(vec![0, -1, 10]));
        let got = run(
            floats(&[1.0, 5.0, 9.0]),
            floats(&[0.0, 0.0, 0.0]),
            floats(&[10.0, 10.0, 10.0]),
            bucket_counts,
        );
        assert!(got.is_null(0));
        assert!(got.is_null(1));
        assert_eq!(got.value(2), 10);

        // Equal bounds produce NULL.
        let got = run(
            floats(&[1.0]),
            floats(&[5.0]),
            floats(&[5.0]),
            counts(1, 10),
        );
        assert!(got.is_null(0));

        // A NaN value produces NULL.
        let got = run(
            floats_opt(&[Some(f64::NAN)]),
            floats(&[0.0]),
            floats(&[10.0]),
            counts(1, 10),
        );
        assert!(got.is_null(0));
    }

    #[test]
    fn test_width_bucket_f64_nulls_propagate() {
        let got = run(
            floats_opt(&[None, Some(1.0), Some(2.0), Some(3.0)]),
            floats(&[0.0; 4]),
            floats(&[10.0; 4]),
            counts(4, 10),
        );
        assert!(got.is_null(0));
        assert_eq!(got.value(1), 2);
        assert_eq!(got.value(2), 3);
        assert_eq!(got.value(3), 4);

        let got = run(
            floats(&[1.0]),
            floats_opt(&[None]),
            floats(&[10.0]),
            counts(1, 10),
        );
        assert!(got.is_null(0));
    }

    // ---- Duration(Microsecond) --------------------------------------------

    #[test]
    fn test_width_bucket_duration_us() {
        let got = run(
            durations_us(&[1_000_000, 0, -1]),
            durations_us(&[0, 0, 0]),
            durations_us(&[2_000_000, 2_000_000, 2_000_000]),
            counts(3, 2),
        );
        assert_eq!(got.values(), &[2, 1, 0]);
    }

    #[test]
    fn test_width_bucket_duration_us_equal_bounds() {
        let got = run(
            durations_us(&[0]),
            durations_us(&[1]),
            durations_us(&[1]),
            counts(1, 10),
        );
        assert!(got.is_null(0));
    }

    // ---- Interval(YearMonth) ----------------------------------------------

    #[test]
    fn test_width_bucket_interval_ym_basic() {
        let got = run(
            year_months(&[0, 5, 11, 12, 13]),
            year_months(&[0; 5]),
            year_months(&[12; 5]),
            counts(5, 12),
        );
        assert_eq!(got.values(), &[1, 6, 12, 13, 13]);
    }

    #[test]
    fn test_width_bucket_interval_ym_desc() {
        let got = run(
            year_months(&[11, 12, 0, -1, 13]),
            year_months(&[12; 5]),
            year_months(&[0; 5]),
            counts(5, 12),
        );
        assert_eq!(got.values(), &[2, 1, 13, 13, 0]);
    }

    // ---- Interval(MonthDayNano) -------------------------------------------

    #[test]
    fn test_width_bucket_interval_mdn_months_only_basic() {
        let got = run(
            month_day_nanos(&[(0, 0, 0), (5, 0, 0), (11, 0, 0), (12, 0, 0), (13, 0, 0)]),
            month_day_nanos(&[(0, 0, 0); 5]),
            month_day_nanos(&[(12, 0, 0); 5]),
            counts(5, 12),
        );
        assert_eq!(got.values(), &[1, 6, 12, 13, 13]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_months_only_desc() {
        let got = run(
            month_day_nanos(&[(11, 0, 0), (12, 0, 0), (0, 0, 0), (-1, 0, 0), (13, 0, 0)]),
            month_day_nanos(&[(12, 0, 0); 5]),
            month_day_nanos(&[(0, 0, 0); 5]),
            counts(5, 12),
        );
        assert_eq!(got.values(), &[2, 1, 13, 13, 0]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_day_nano_basic() {
        let got = run(
            month_day_nanos(&[
                (0, 0, 0),
                (0, 5, 0),
                (0, 9, 0),
                (0, 10, 0),
                (0, -1, 0),
                (0, 11, 0),
            ]),
            month_day_nanos(&[(0, 0, 0); 6]),
            month_day_nanos(&[(0, 10, 0); 6]),
            counts(6, 10),
        );
        assert_eq!(got.values(), &[1, 6, 10, 11, 0, 11]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc() {
        let got = run(
            month_day_nanos(&[(0, 9, 0), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]),
            month_day_nanos(&[(0, 10, 0); 5]),
            month_day_nanos(&[(0, 0, 0); 5]),
            counts(5, 10),
        );
        assert_eq!(got.values(), &[2, 1, 11, 11, 0]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_day_nano_desc_inside() {
        let got = run(
            month_day_nanos(&[(0, 9, 1), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]),
            month_day_nanos(&[(0, 10, 0); 5]),
            month_day_nanos(&[(0, 0, 0); 5]),
            counts(5, 10),
        );
        assert_eq!(got.values(), &[1, 1, 11, 11, 0]);
    }

    #[test]
    fn test_width_bucket_interval_mdn_mixed_months_and_days_is_null() {
        let got = run(
            month_day_nanos(&[(0, 1, 0)]),
            month_day_nanos(&[(0, 0, 0)]),
            month_day_nanos(&[(1, 1, 0)]),
            counts(1, 4),
        );
        assert!(got.is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_equal_bounds_is_null() {
        let got = run(
            month_day_nanos(&[(0, 0, 0)]),
            month_day_nanos(&[(1, 2, 3)]),
            month_day_nanos(&[(1, 2, 3)]),
            counts(1, 10),
        );
        assert!(got.is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_invalid_n_is_null() {
        let zero_buckets: ArrayRef = Arc::new(Int32Array::from(vec![0]));
        let got = run(
            month_day_nanos(&[(0, 0, 0)]),
            month_day_nanos(&[(0, 0, 0)]),
            month_day_nanos(&[(0, 10, 0)]),
            zero_buckets,
        );
        assert!(got.is_null(0));
    }

    #[test]
    fn test_width_bucket_interval_mdn_nulls_propagate() {
        let values: ArrayRef = Arc::new(IntervalMonthDayNanoArray::from(vec![
            None,
            Some(IntervalMonthDayNano::new(0, 5, 0)),
        ]));
        let got = run(
            values,
            month_day_nanos(&[(0, 0, 0), (0, 0, 0)]),
            month_day_nanos(&[(0, 10, 0), (0, 10, 0)]),
            counts(2, 10),
        );
        assert!(got.is_null(0));
        assert_eq!(got.value(1), 6);
    }

    // ---- error paths ------------------------------------------------------

    #[test]
    fn test_width_bucket_wrong_arg_count() {
        let three_args = [floats(&[1.0]), floats(&[0.0]), floats(&[10.0])];
        let err = width_bucket_kern(&three_args).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("expects exactly 4"), "unexpected error: {msg}");
    }

    #[test]
    fn test_width_bucket_unsupported_type() {
        // An Int32 value column is not one of the supported operand types.
        let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let args = [
            values,
            floats(&[0.0, 0.0, 0.0]),
            floats(&[10.0, 10.0, 10.0]),
            counts(3, 10),
        ];

        let err = width_bucket_kern(&args).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("width_bucket received unexpected data types"),
            "unexpected error: {msg}"
        );
    }
}