use std::any::Any;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use arrow::datatypes::{
    DataType, Int64Type, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
    TimestampNanosecondType, TimestampSecondType,
};
use arrow_array::cast::AsArray;
use arrow_array::{ArrayRef, TimestampMillisecondArray};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
24
25#[derive(Debug)]
41pub struct TumbleWindowStart {
42 signature: Signature,
43}
44
45impl TumbleWindowStart {
46 #[must_use]
48 pub fn new() -> Self {
49 Self {
50 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
51 }
52 }
53}
54
55impl Default for TumbleWindowStart {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl PartialEq for TumbleWindowStart {
62 fn eq(&self, _other: &Self) -> bool {
63 true }
65}
66
67impl Eq for TumbleWindowStart {}
68
69impl Hash for TumbleWindowStart {
70 fn hash<H: Hasher>(&self, state: &mut H) {
71 "tumble".hash(state);
72 }
73}
74
75impl ScalarUDFImpl for TumbleWindowStart {
76 fn as_any(&self) -> &dyn Any {
77 self
78 }
79
80 fn name(&self) -> &'static str {
81 "tumble"
82 }
83
84 fn signature(&self) -> &Signature {
85 &self.signature
86 }
87
88 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
89 Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
90 }
91
92 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
93 let ScalarFunctionArgs { args, .. } = args;
94 if args.len() != 2 {
95 return Err(DataFusionError::Plan(
96 "tumble() requires exactly 2 arguments: (timestamp, interval)".to_string(),
97 ));
98 }
99 let interval_ms = extract_interval_ms(&args[1])?;
100 if interval_ms <= 0 {
101 return Err(DataFusionError::Plan(
102 "tumble() interval must be positive".to_string(),
103 ));
104 }
105 compute_tumble(&args[0], interval_ms)
106 }
107}
108
109#[derive(Debug)]
127pub struct HopWindowStart {
128 signature: Signature,
129}
130
131impl HopWindowStart {
132 #[must_use]
134 pub fn new() -> Self {
135 Self {
136 signature: Signature::new(TypeSignature::Any(3), Volatility::Immutable),
137 }
138 }
139}
140
141impl Default for HopWindowStart {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147impl PartialEq for HopWindowStart {
148 fn eq(&self, _other: &Self) -> bool {
149 true
150 }
151}
152
153impl Eq for HopWindowStart {}
154
155impl Hash for HopWindowStart {
156 fn hash<H: Hasher>(&self, state: &mut H) {
157 "hop".hash(state);
158 }
159}
160
161impl ScalarUDFImpl for HopWindowStart {
162 fn as_any(&self) -> &dyn Any {
163 self
164 }
165
166 fn name(&self) -> &'static str {
167 "hop"
168 }
169
170 fn signature(&self) -> &Signature {
171 &self.signature
172 }
173
174 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
175 Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
176 }
177
178 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
179 let ScalarFunctionArgs { args, .. } = args;
180 if args.len() != 3 {
181 return Err(DataFusionError::Plan(
182 "hop() requires exactly 3 arguments: (timestamp, slide, size)".to_string(),
183 ));
184 }
185 let slide_ms = extract_interval_ms(&args[1])?;
186 let size_ms = extract_interval_ms(&args[2])?;
187 if slide_ms <= 0 || size_ms <= 0 {
188 return Err(DataFusionError::Plan(
189 "hop() slide and size must be positive".to_string(),
190 ));
191 }
192 compute_hop(&args[0], slide_ms, size_ms)
193 }
194}
195
196#[derive(Debug)]
209pub struct SessionWindowStart {
210 signature: Signature,
211}
212
213impl SessionWindowStart {
214 #[must_use]
216 pub fn new() -> Self {
217 Self {
218 signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
219 }
220 }
221}
222
223impl Default for SessionWindowStart {
224 fn default() -> Self {
225 Self::new()
226 }
227}
228
229impl PartialEq for SessionWindowStart {
230 fn eq(&self, _other: &Self) -> bool {
231 true
232 }
233}
234
235impl Eq for SessionWindowStart {}
236
237impl Hash for SessionWindowStart {
238 fn hash<H: Hasher>(&self, state: &mut H) {
239 "session".hash(state);
240 }
241}
242
243impl ScalarUDFImpl for SessionWindowStart {
244 fn as_any(&self) -> &dyn Any {
245 self
246 }
247
248 fn name(&self) -> &'static str {
249 "session"
250 }
251
252 fn signature(&self) -> &Signature {
253 &self.signature
254 }
255
256 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
257 Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
258 }
259
260 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
261 let ScalarFunctionArgs { args, .. } = args;
262 if args.len() != 2 {
263 return Err(DataFusionError::Plan(
264 "session() requires exactly 2 arguments: (timestamp, gap)".to_string(),
265 ));
266 }
267 match &args[0] {
269 ColumnarValue::Array(array) => {
270 let result = convert_to_timestamp_ms_array(array)?;
271 Ok(ColumnarValue::Array(result))
272 }
273 ColumnarValue::Scalar(scalar) => {
274 let ts_ms = scalar_to_timestamp_ms(scalar)?;
275 Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
276 ts_ms, None,
277 )))
278 }
279 }
280 }
281}
282
283fn extract_interval_ms(value: &ColumnarValue) -> Result<i64> {
290 match value {
291 ColumnarValue::Scalar(scalar) => scalar_interval_to_ms(scalar),
292 ColumnarValue::Array(_) => Err(DataFusionError::NotImplemented(
293 "Array interval arguments not supported for window functions".to_string(),
294 )),
295 }
296}
297
298fn scalar_interval_to_ms(scalar: &ScalarValue) -> Result<i64> {
300 match scalar {
301 ScalarValue::IntervalDayTime(Some(v)) => {
302 Ok(i64::from(v.days) * 86_400_000 + i64::from(v.milliseconds))
303 }
304 ScalarValue::IntervalMonthDayNano(Some(v)) => {
305 if v.months != 0 {
306 return Err(DataFusionError::NotImplemented(
307 "Month-based intervals not supported for window functions \
308 (use days/hours/minutes/seconds)"
309 .to_string(),
310 ));
311 }
312 Ok(i64::from(v.days) * 86_400_000 + v.nanoseconds / 1_000_000)
313 }
314 ScalarValue::IntervalYearMonth(_) => Err(DataFusionError::NotImplemented(
315 "Year-month intervals not supported for window functions".to_string(),
316 )),
317 ScalarValue::Int64(Some(ms)) => Ok(*ms),
318 _ => Err(DataFusionError::Plan(format!(
319 "Expected interval argument for window function, got: {scalar:?}"
320 ))),
321 }
322}
323
324fn scalar_to_timestamp_ms(scalar: &ScalarValue) -> Result<Option<i64>> {
326 match scalar {
327 ScalarValue::TimestampMillisecond(v, _) | ScalarValue::Int64(v) => Ok(*v),
328 ScalarValue::TimestampMicrosecond(v, _) => Ok(v.map(|v| v / 1_000)),
329 ScalarValue::TimestampNanosecond(v, _) => Ok(v.map(|v| v / 1_000_000)),
330 ScalarValue::TimestampSecond(v, _) => Ok(v.map(|v| v * 1_000)),
331 _ => Err(DataFusionError::Plan(format!(
332 "Expected timestamp argument for window function, got: {scalar:?}"
333 ))),
334 }
335}
336
337fn compute_tumble(value: &ColumnarValue, interval_ms: i64) -> Result<ColumnarValue> {
339 match value {
340 ColumnarValue::Array(array) => {
341 let result = compute_tumble_array(array, interval_ms)?;
342 Ok(ColumnarValue::Array(result))
343 }
344 ColumnarValue::Scalar(scalar) => {
345 let ts_ms = scalar_to_timestamp_ms(scalar)?;
346 let window_start = ts_ms.map(|ts| ts - ts.rem_euclid(interval_ms));
347 Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
348 window_start,
349 None,
350 )))
351 }
352 }
353}
354
355fn compute_tumble_array(array: &ArrayRef, interval_ms: i64) -> Result<ArrayRef> {
357 match array.data_type() {
358 DataType::Timestamp(TimeUnit::Millisecond, _) => {
359 let input = array.as_primitive::<TimestampMillisecondType>();
360 let result: TimestampMillisecondArray = input
361 .iter()
362 .map(|opt_ts| opt_ts.map(|ts| ts - ts.rem_euclid(interval_ms)))
363 .collect();
364 Ok(Arc::new(result))
365 }
366 DataType::Int64 => {
367 let input = array.as_primitive::<Int64Type>();
368 let result: TimestampMillisecondArray = input
369 .iter()
370 .map(|opt_ts| opt_ts.map(|ts| ts - ts.rem_euclid(interval_ms)))
371 .collect();
372 Ok(Arc::new(result))
373 }
374 other => Err(DataFusionError::Plan(format!(
375 "Unsupported timestamp type for tumble(): {other:?}. \
376 Use TimestampMillisecond or Int64."
377 ))),
378 }
379}
380
381fn compute_hop(value: &ColumnarValue, slide_ms: i64, size_ms: i64) -> Result<ColumnarValue> {
383 match value {
384 ColumnarValue::Array(array) => {
385 let result = compute_hop_array(array, slide_ms, size_ms)?;
386 Ok(ColumnarValue::Array(result))
387 }
388 ColumnarValue::Scalar(scalar) => {
389 let ts_ms = scalar_to_timestamp_ms(scalar)?;
390 let window_start = ts_ms.map(|ts| hop_earliest_start(ts, slide_ms, size_ms));
391 Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
392 window_start,
393 None,
394 )))
395 }
396 }
397}
398
399fn compute_hop_array(array: &ArrayRef, slide_ms: i64, size_ms: i64) -> Result<ArrayRef> {
401 match array.data_type() {
402 DataType::Timestamp(TimeUnit::Millisecond, _) => {
403 let input = array.as_primitive::<TimestampMillisecondType>();
404 let result: TimestampMillisecondArray = input
405 .iter()
406 .map(|opt_ts| opt_ts.map(|ts| hop_earliest_start(ts, slide_ms, size_ms)))
407 .collect();
408 Ok(Arc::new(result))
409 }
410 DataType::Int64 => {
411 let input = array.as_primitive::<Int64Type>();
412 let result: TimestampMillisecondArray = input
413 .iter()
414 .map(|opt_ts| opt_ts.map(|ts| hop_earliest_start(ts, slide_ms, size_ms)))
415 .collect();
416 Ok(Arc::new(result))
417 }
418 other => Err(DataFusionError::Plan(format!(
419 "Unsupported timestamp type for hop(): {other:?}. \
420 Use TimestampMillisecond or Int64."
421 ))),
422 }
423}
424
/// Returns the start of the earliest hopping window that contains `ts`.
///
/// Windows start at multiples of `slide_ms` and cover `[start, start + size_ms)`.
/// The wanted start is the smallest multiple of `slide_ms` strictly greater than
/// `ts - size_ms`. That equals `floor((ts - size_ms + slide_ms) / slide_ms) * slide_ms`,
/// computed with Euclidean division so negative timestamps floor correctly.
///
/// When `size_ms < slide_ms`, `ts` may fall in a gap between windows. The value
/// returned is then the next window start after `ts`.
///
/// `slide_ms` must be non-zero (`div_euclid` panics on zero); callers reject
/// non-positive values.
#[inline]
fn hop_earliest_start(ts: i64, slide_ms: i64, size_ms: i64) -> i64 {
    let shifted = ts - size_ms + slide_ms;
    shifted.div_euclid(slide_ms) * slide_ms
}
434
435fn convert_to_timestamp_ms_array(array: &ArrayRef) -> Result<ArrayRef> {
437 match array.data_type() {
438 DataType::Timestamp(TimeUnit::Millisecond, _) => Ok(Arc::clone(array)),
439 DataType::Int64 => {
440 let input = array.as_primitive::<Int64Type>();
441 let result: TimestampMillisecondArray = input.iter().collect();
442 Ok(Arc::new(result))
443 }
444 other => Err(DataFusionError::Plan(format!(
445 "Unsupported timestamp type for session(): {other:?}. \
446 Use TimestampMillisecond or Int64."
447 ))),
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};
    use arrow_array::{Array, Int64Array};
    use arrow_schema::Field;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ScalarUDF;

    /// Scalar `IntervalDayTime` of `days` days plus `ms` milliseconds.
    fn interval_dt(days: i32, ms: i32) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new(
            days, ms,
        ))))
    }

    /// Scalar millisecond timestamp (no timezone); `None` is SQL NULL.
    fn ts_ms(ms: Option<i64>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(ms, None))
    }

    /// Unwraps a scalar millisecond-timestamp result, panicking on anything else.
    fn expect_ts_ms(result: ColumnarValue) -> Option<i64> {
        match result {
            ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, _)) => v,
            other => panic!("Expected TimestampMillisecond scalar, got: {other:?}"),
        }
    }

    /// Builds invocation arguments declaring a millisecond-timestamp return field.
    fn make_args(args: Vec<ColumnarValue>, rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows: rows,
            return_field: Arc::new(Field::new(
                "output",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                true,
            )),
            config_options: Arc::new(ConfigOptions::default()),
        }
    }

    // ---- tumble ----

    #[test]
    fn test_tumble_basic() {
        // 7 minutes into a 5-minute tumbling window grid -> window starts at 5 min.
        let udf = TumbleWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(
                vec![ts_ms(Some(420_000)), interval_dt(0, 300_000)],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(300_000));
    }

    #[test]
    fn test_tumble_exact_boundary() {
        let udf = TumbleWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(
                vec![ts_ms(Some(300_000)), interval_dt(0, 300_000)],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(300_000));
    }

    #[test]
    fn test_tumble_zero_timestamp() {
        let udf = TumbleWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(Some(0)), interval_dt(0, 300_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(0));
    }

    #[test]
    fn test_tumble_negative_timestamp() {
        // Pre-epoch timestamps must floor to the earlier window, not toward zero.
        let udf = TumbleWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(Some(-1)), interval_dt(0, 300_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(-300_000));
    }

    #[test]
    fn test_tumble_null_handling() {
        let udf = TumbleWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(None), interval_dt(0, 300_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), None);
    }

    #[test]
    fn test_tumble_array_input() {
        let udf = TumbleWindowStart::new();
        let ts_array = TimestampMillisecondArray::from(vec![
            Some(0),
            Some(150_000),
            Some(300_000),
            Some(420_000),
            None,
        ]);
        let ts = ColumnarValue::Array(Arc::new(ts_array));
        let interval = interval_dt(0, 300_000);

        let result = udf
            .invoke_with_args(make_args(vec![ts, interval], 5))
            .unwrap();
        match result {
            ColumnarValue::Array(arr) => {
                let r = arr.as_primitive::<TimestampMillisecondType>();
                assert_eq!(r.value(0), 0);
                assert_eq!(r.value(1), 0);
                assert_eq!(r.value(2), 300_000);
                assert_eq!(r.value(3), 300_000);
                assert!(r.is_null(4));
            }
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_tumble_int64_array_input() {
        // Int64 columns are treated as epoch milliseconds.
        let udf = TumbleWindowStart::new();
        let ts = ColumnarValue::Array(Arc::new(Int64Array::from(vec![Some(125_000), None])));
        let result = udf
            .invoke_with_args(make_args(vec![ts, interval_dt(0, 60_000)], 2))
            .unwrap();
        match result {
            ColumnarValue::Array(arr) => {
                let r = arr.as_primitive::<TimestampMillisecondType>();
                assert_eq!(r.value(0), 120_000);
                assert!(r.is_null(1));
            }
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_tumble_month_day_nano_interval() {
        // 1 hour expressed in nanoseconds; 1.5 h floors to 1 h.
        let udf = TumbleWindowStart::new();
        let interval = ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(
            IntervalMonthDayNano::new(0, 0, 3_600_000_000_000),
        )));
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(Some(5_400_000)), interval], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(3_600_000));
    }

    #[test]
    fn test_tumble_int64_interval() {
        // A raw Int64 interval is taken as milliseconds.
        let udf = TumbleWindowStart::new();
        let interval = ColumnarValue::Scalar(ScalarValue::Int64(Some(60_000)));
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(Some(125_000)), interval], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(120_000));
    }

    #[test]
    fn test_tumble_rejects_month_interval() {
        let udf = TumbleWindowStart::new();
        let interval = ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(
            IntervalMonthDayNano::new(1, 0, 0),
        )));
        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000)), interval], 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_tumble_rejects_array_interval() {
        let udf = TumbleWindowStart::new();
        let interval = ColumnarValue::Array(Arc::new(Int64Array::from(vec![60_000])));
        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000)), interval], 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_tumble_rejects_zero_interval() {
        let udf = TumbleWindowStart::new();
        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000)), interval_dt(0, 0)], 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_tumble_rejects_wrong_arg_count() {
        let udf = TumbleWindowStart::new();
        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000))], 1));
        assert!(result.is_err());
    }

    // ---- hop ----

    #[test]
    fn test_hop_basic() {
        // slide = 5 min, size = 10 min: 7 min lies in [0, 10) and [5, 15); earliest is 0.
        let udf = HopWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(420_000)),
                    interval_dt(0, 300_000),
                    interval_dt(0, 600_000),
                ],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(0));
    }

    #[test]
    fn test_hop_at_boundary() {
        // 5 min lies in [0, 10) and [5, 15), but not in [-5, 5).
        let udf = HopWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(300_000)),
                    interval_dt(0, 300_000),
                    interval_dt(0, 600_000),
                ],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(0));
    }

    #[test]
    fn test_hop_array_input() {
        let udf = HopWindowStart::new();
        let ts_array =
            TimestampMillisecondArray::from(vec![Some(420_000), Some(300_000), Some(-1), None]);
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ColumnarValue::Array(Arc::new(ts_array)),
                    interval_dt(0, 300_000),
                    interval_dt(0, 600_000),
                ],
                4,
            ))
            .unwrap();
        match result {
            ColumnarValue::Array(arr) => {
                let r = arr.as_primitive::<TimestampMillisecondType>();
                assert_eq!(r.value(0), 0);
                assert_eq!(r.value(1), 0);
                // -1 ms lies in [-10 min, 0) and [-5 min, 5 min); earliest is -10 min.
                assert_eq!(r.value(2), -600_000);
                assert!(r.is_null(3));
            }
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_hop_rejects_non_positive_args() {
        let udf = HopWindowStart::new();
        let zero_slide = udf.invoke_with_args(make_args(
            vec![ts_ms(Some(1000)), interval_dt(0, 0), interval_dt(0, 600_000)],
            1,
        ));
        assert!(zero_slide.is_err());

        let negative_size = udf.invoke_with_args(make_args(
            vec![ts_ms(Some(1000)), interval_dt(0, 300_000), interval_dt(0, -1)],
            1,
        ));
        assert!(negative_size.is_err());
    }

    #[test]
    fn test_hop_rejects_wrong_arg_count() {
        let udf = HopWindowStart::new();
        let result = udf.invoke_with_args(make_args(
            vec![ts_ms(Some(1000)), interval_dt(0, 300_000)],
            1,
        ));
        assert!(result.is_err());
    }

    // ---- session ----

    #[test]
    fn test_session_passthrough_scalar() {
        // session() returns the row's own timestamp; the gap is not applied here.
        let udf = SessionWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(
                vec![ts_ms(Some(42_000)), interval_dt(0, 60_000)],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(42_000));
    }

    #[test]
    fn test_session_passthrough_null() {
        let udf = SessionWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(None), interval_dt(0, 60_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), None);
    }

    #[test]
    fn test_session_int64_array_input() {
        let udf = SessionWindowStart::new();
        let ts = ColumnarValue::Array(Arc::new(Int64Array::from(vec![Some(42_000), None])));
        let result = udf
            .invoke_with_args(make_args(vec![ts, interval_dt(0, 60_000)], 2))
            .unwrap();
        match result {
            ColumnarValue::Array(arr) => {
                assert_eq!(
                    arr.data_type(),
                    &DataType::Timestamp(TimeUnit::Millisecond, None)
                );
                let r = arr.as_primitive::<TimestampMillisecondType>();
                assert_eq!(r.value(0), 42_000);
                assert!(r.is_null(1));
            }
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    // ---- registration / metadata ----

    #[test]
    fn test_udf_registration() {
        let tumble = ScalarUDF::new_from_impl(TumbleWindowStart::new());
        assert_eq!(tumble.name(), "tumble");

        let hop = ScalarUDF::new_from_impl(HopWindowStart::new());
        assert_eq!(hop.name(), "hop");

        let session = ScalarUDF::new_from_impl(SessionWindowStart::new());
        assert_eq!(session.name(), "session");
    }

    #[test]
    fn test_udf_signatures_immutable() {
        assert_eq!(
            TumbleWindowStart::new().signature().volatility,
            Volatility::Immutable
        );
        assert_eq!(
            HopWindowStart::new().signature().volatility,
            Volatility::Immutable
        );
        assert_eq!(
            SessionWindowStart::new().signature().volatility,
            Volatility::Immutable
        );
    }

    #[test]
    fn test_tumble_return_type() {
        let udf = TumbleWindowStart::new();
        let rt = udf.return_type(&[]).unwrap();
        assert_eq!(rt, DataType::Timestamp(TimeUnit::Millisecond, None));
    }
}