Skip to main content

laminar_sql/datafusion/
window_udf.rs

//! Window function UDFs for `DataFusion` integration
//!
//! Provides scalar UDFs that compute window start timestamps for
//! streaming window operations:
//!
//! - [`TumbleWindowStart`] — `tumble(timestamp, interval [, offset])` — fixed-size
//!   non-overlapping windows
//! - [`HopWindowStart`] — `hop(timestamp, slide, size [, offset])` — fixed-size
//!   overlapping windows (returns the earliest containing window's start)
//! - [`SessionWindowStart`] — `session(timestamp, gap)` — pass-through for Ring 0 sessions
//! - [`CumulateWindowStart`] — `cumulate(timestamp, step, size)` — epoch start of
//!   cumulating windows
//!
//! These UDFs allow `DataFusion` to execute `GROUP BY TUMBLE(...)` style queries
//! by computing the window start as a per-row scalar value.
12
13use std::any::Any;
14use std::hash::{Hash, Hasher};
15use std::sync::Arc;
16
17use arrow::datatypes::{DataType, TimeUnit};
18use arrow_array::{ArrayRef, TimestampMillisecondArray};
19use datafusion_common::{DataFusionError, Result, ScalarValue};
20use datafusion_expr::{
21    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
22};
23use laminar_core::time::cast_to_millis_array;
24
25// ─── TumbleWindowStart ───────────────────────────────────────────────────────
26
27/// Computes the tumbling window start for a given timestamp.
28///
29/// `tumble(timestamp, interval)` returns `floor(ts / interval) * interval`,
30/// which is the start of the non-overlapping window that contains `ts`.
31///
32/// # Arguments
33///
34/// * Arg 0: Timestamp column or scalar (`TimestampMillisecond` or `Int64` ms)
35/// * Arg 1: Window size as an interval scalar
36///
37/// # Returns
38///
39/// `TimestampMillisecond` representing the window start.
40#[derive(Debug)]
41pub struct TumbleWindowStart {
42    signature: Signature,
43}
44
45impl TumbleWindowStart {
46    /// Creates a new tumble window start UDF.
47    #[must_use]
48    pub fn new() -> Self {
49        Self {
50            signature: Signature::new(
51                TypeSignature::OneOf(vec![TypeSignature::Any(2), TypeSignature::Any(3)]),
52                Volatility::Immutable,
53            ),
54        }
55    }
56}
57
58impl Default for TumbleWindowStart {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl PartialEq for TumbleWindowStart {
65    fn eq(&self, _other: &Self) -> bool {
66        true // All instances are identical
67    }
68}
69
70impl Eq for TumbleWindowStart {}
71
72impl Hash for TumbleWindowStart {
73    fn hash<H: Hasher>(&self, state: &mut H) {
74        "tumble".hash(state);
75    }
76}
77
78impl ScalarUDFImpl for TumbleWindowStart {
79    fn as_any(&self) -> &dyn Any {
80        self
81    }
82
83    fn name(&self) -> &'static str {
84        "tumble"
85    }
86
87    fn signature(&self) -> &Signature {
88        &self.signature
89    }
90
91    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
92        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
93    }
94
95    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
96        let ScalarFunctionArgs { args, .. } = args;
97        if args.len() < 2 || args.len() > 3 {
98            return Err(DataFusionError::Plan(
99                "tumble() requires 2-3 arguments: (timestamp, interval [, offset])".to_string(),
100            ));
101        }
102        let interval_ms = extract_interval_ms(&args[1])?;
103        if interval_ms <= 0 {
104            return Err(DataFusionError::Plan(
105                "tumble() interval must be positive".to_string(),
106            ));
107        }
108        let offset_ms = if args.len() == 3 {
109            extract_interval_ms(&args[2])?
110        } else {
111            0
112        };
113        compute_tumble_with_offset(&args[0], interval_ms, offset_ms)
114    }
115}
116
117// ─── HopWindowStart ──────────────────────────────────────────────────────────
118
119/// Computes the earliest hopping window start for a given timestamp.
120///
121/// `hop(timestamp, slide, size)` returns the start of the earliest window
122/// (of the given `size`, sliding by `slide`) that contains `ts`.
123///
124/// # Limitation
125///
126/// This returns only the *earliest* window start. Full multi-window
127/// assignment (one row per window) is handled by Ring 0 operators.
128///
129/// # Arguments
130///
131/// * Arg 0: Timestamp column or scalar
132/// * Arg 1: Slide interval scalar
133/// * Arg 2: Window size interval scalar
134#[derive(Debug)]
135pub struct HopWindowStart {
136    signature: Signature,
137}
138
139impl HopWindowStart {
140    /// Creates a new hop window start UDF.
141    #[must_use]
142    pub fn new() -> Self {
143        Self {
144            signature: Signature::new(
145                TypeSignature::OneOf(vec![TypeSignature::Any(3), TypeSignature::Any(4)]),
146                Volatility::Immutable,
147            ),
148        }
149    }
150}
151
152impl Default for HopWindowStart {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158impl PartialEq for HopWindowStart {
159    fn eq(&self, _other: &Self) -> bool {
160        true
161    }
162}
163
164impl Eq for HopWindowStart {}
165
166impl Hash for HopWindowStart {
167    fn hash<H: Hasher>(&self, state: &mut H) {
168        "hop".hash(state);
169    }
170}
171
172impl ScalarUDFImpl for HopWindowStart {
173    fn as_any(&self) -> &dyn Any {
174        self
175    }
176
177    fn name(&self) -> &'static str {
178        "hop"
179    }
180
181    fn signature(&self) -> &Signature {
182        &self.signature
183    }
184
185    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
186        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
187    }
188
189    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
190        let ScalarFunctionArgs { args, .. } = args;
191        if args.len() < 3 || args.len() > 4 {
192            return Err(DataFusionError::Plan(
193                "hop() requires 3-4 arguments: (timestamp, slide, size [, offset])".to_string(),
194            ));
195        }
196        let slide_ms = extract_interval_ms(&args[1])?;
197        let size_ms = extract_interval_ms(&args[2])?;
198        if slide_ms <= 0 || size_ms <= 0 {
199            return Err(DataFusionError::Plan(
200                "hop() slide and size must be positive".to_string(),
201            ));
202        }
203        let offset_ms = if args.len() == 4 {
204            extract_interval_ms(&args[3])?
205        } else {
206            0
207        };
208        compute_hop_with_offset(&args[0], slide_ms, size_ms, offset_ms)
209    }
210}
211
212// ─── SessionWindowStart ──────────────────────────────────────────────────────
213
214/// Pass-through UDF for session window compatibility.
215///
216/// `session(timestamp, gap)` returns the input timestamp unchanged.
217/// Session windows are data-dependent (gap-based grouping) and cannot
218/// be computed as a per-row scalar. The actual session assignment is
219/// handled by Ring 0 operators.
220///
221/// This UDF exists so that `GROUP BY SESSION(ts, gap)` is syntactically
222/// valid in `DataFusion` queries, with real session logic deferred to
223/// the streaming engine.
224#[derive(Debug)]
225pub struct SessionWindowStart {
226    signature: Signature,
227}
228
229impl SessionWindowStart {
230    /// Creates a new session window start UDF.
231    #[must_use]
232    pub fn new() -> Self {
233        Self {
234            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
235        }
236    }
237}
238
239impl Default for SessionWindowStart {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
245impl PartialEq for SessionWindowStart {
246    fn eq(&self, _other: &Self) -> bool {
247        true
248    }
249}
250
251impl Eq for SessionWindowStart {}
252
253impl Hash for SessionWindowStart {
254    fn hash<H: Hasher>(&self, state: &mut H) {
255        "session".hash(state);
256    }
257}
258
259impl ScalarUDFImpl for SessionWindowStart {
260    fn as_any(&self) -> &dyn Any {
261        self
262    }
263
264    fn name(&self) -> &'static str {
265        "session"
266    }
267
268    fn signature(&self) -> &Signature {
269        &self.signature
270    }
271
272    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
273        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
274    }
275
276    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
277        let ScalarFunctionArgs { args, .. } = args;
278        if args.len() != 2 {
279            return Err(DataFusionError::Plan(
280                "session() requires exactly 2 arguments: (timestamp, gap)".to_string(),
281            ));
282        }
283        // Pass-through: return the input timestamp as-is
284        match &args[0] {
285            ColumnarValue::Array(array) => {
286                let result = convert_to_timestamp_ms_array(array)?;
287                Ok(ColumnarValue::Array(result))
288            }
289            ColumnarValue::Scalar(scalar) => {
290                let ts_ms = scalar_to_timestamp_ms(scalar)?;
291                Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
292                    ts_ms, None,
293                )))
294            }
295        }
296    }
297}
298
299// ─── CumulateWindowStart ────────────────────────────────────────────────────
300
301/// Computes the cumulate window epoch start for a given timestamp.
302///
303/// `cumulate(timestamp, step, size)` returns `floor(ts / size) * size`,
304/// which is the epoch start for the cumulating window that contains `ts`.
305/// The actual multi-window assignment (one row per cumulating window) is
306/// handled by Ring 0 operators.
307///
308/// # Arguments
309///
310/// * Arg 0: Timestamp column or scalar (`TimestampMillisecond` or `Int64` ms)
311/// * Arg 1: Step interval scalar (window growth increment)
312/// * Arg 2: Max size interval scalar (epoch size)
313#[derive(Debug)]
314pub struct CumulateWindowStart {
315    signature: Signature,
316}
317
318impl CumulateWindowStart {
319    /// Creates a new cumulate window start UDF.
320    #[must_use]
321    pub fn new() -> Self {
322        Self {
323            signature: Signature::new(TypeSignature::Any(3), Volatility::Immutable),
324        }
325    }
326}
327
328impl Default for CumulateWindowStart {
329    fn default() -> Self {
330        Self::new()
331    }
332}
333
334impl PartialEq for CumulateWindowStart {
335    fn eq(&self, _other: &Self) -> bool {
336        true
337    }
338}
339
340impl Eq for CumulateWindowStart {}
341
342impl Hash for CumulateWindowStart {
343    fn hash<H: Hasher>(&self, state: &mut H) {
344        "cumulate".hash(state);
345    }
346}
347
348impl ScalarUDFImpl for CumulateWindowStart {
349    fn as_any(&self) -> &dyn Any {
350        self
351    }
352
353    fn name(&self) -> &'static str {
354        "cumulate"
355    }
356
357    fn signature(&self) -> &Signature {
358        &self.signature
359    }
360
361    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
362        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
363    }
364
365    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
366        let ScalarFunctionArgs { args, .. } = args;
367        if args.len() != 3 {
368            return Err(DataFusionError::Plan(
369                "cumulate() requires exactly 3 arguments: (timestamp, step, size)".to_string(),
370            ));
371        }
372        let step_ms = extract_interval_ms(&args[1])?;
373        let size_ms = extract_interval_ms(&args[2])?;
374        if step_ms <= 0 || size_ms <= 0 {
375            return Err(DataFusionError::Plan(
376                "cumulate() step and size must be positive".to_string(),
377            ));
378        }
379        if step_ms > size_ms {
380            return Err(DataFusionError::Plan(
381                "cumulate() step must not exceed size".to_string(),
382            ));
383        }
384        if size_ms % step_ms != 0 {
385            return Err(DataFusionError::Plan(
386                "cumulate() size must be evenly divisible by step".to_string(),
387            ));
388        }
389        // Return epoch start = floor(ts / size) * size (same as tumble with size)
390        compute_tumble(&args[0], size_ms)
391    }
392}
393
394// ─── Helper Functions ────────────────────────────────────────────────────────
395
396/// Extracts an interval value in milliseconds from a `ColumnarValue`.
397///
398/// Only scalar intervals are supported (array intervals would require
399/// per-row window sizes, which is not a valid streaming pattern).
400fn extract_interval_ms(value: &ColumnarValue) -> Result<i64> {
401    match value {
402        ColumnarValue::Scalar(scalar) => scalar_interval_to_ms(scalar),
403        ColumnarValue::Array(_) => Err(DataFusionError::NotImplemented(
404            "Array interval arguments not supported for window functions".to_string(),
405        )),
406    }
407}
408
409/// Converts a scalar interval to milliseconds.
410fn scalar_interval_to_ms(scalar: &ScalarValue) -> Result<i64> {
411    match scalar {
412        ScalarValue::IntervalDayTime(Some(v)) => {
413            Ok(i64::from(v.days) * 86_400_000 + i64::from(v.milliseconds))
414        }
415        ScalarValue::IntervalMonthDayNano(Some(v)) => {
416            if v.months != 0 {
417                return Err(DataFusionError::NotImplemented(
418                    "Month-based intervals not supported for window functions \
419                     (use days/hours/minutes/seconds)"
420                        .to_string(),
421                ));
422            }
423            Ok(i64::from(v.days) * 86_400_000 + v.nanoseconds / 1_000_000)
424        }
425        ScalarValue::IntervalYearMonth(_) => Err(DataFusionError::NotImplemented(
426            "Year-month intervals not supported for window functions".to_string(),
427        )),
428        ScalarValue::Int64(Some(ms)) => Ok(*ms),
429        _ => Err(DataFusionError::Plan(format!(
430            "Expected interval argument for window function, got: {scalar:?}"
431        ))),
432    }
433}
434
435/// Converts a scalar value to a timestamp in milliseconds.
436fn scalar_to_timestamp_ms(scalar: &ScalarValue) -> Result<Option<i64>> {
437    match scalar {
438        ScalarValue::TimestampMillisecond(v, _) | ScalarValue::Int64(v) => Ok(*v),
439        ScalarValue::TimestampMicrosecond(v, _) => Ok(v.map(|v| v / 1_000)),
440        ScalarValue::TimestampNanosecond(v, _) => Ok(v.map(|v| v / 1_000_000)),
441        ScalarValue::TimestampSecond(v, _) => Ok(v.map(|v| v * 1_000)),
442        _ => Err(DataFusionError::Plan(format!(
443            "Expected timestamp argument for window function, got: {scalar:?}"
444        ))),
445    }
446}
447
448/// Computes tumble window start for a `ColumnarValue`.
449fn compute_tumble(value: &ColumnarValue, interval_ms: i64) -> Result<ColumnarValue> {
450    match value {
451        ColumnarValue::Array(array) => {
452            let result = compute_tumble_array(array, interval_ms)?;
453            Ok(ColumnarValue::Array(result))
454        }
455        ColumnarValue::Scalar(scalar) => {
456            let ts_ms = scalar_to_timestamp_ms(scalar)?;
457            let window_start = ts_ms.map(|ts| ts - ts.rem_euclid(interval_ms));
458            Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
459                window_start,
460                None,
461            )))
462        }
463    }
464}
465
466/// Normalise a timestamp array to `TimestampMillisecond` for the window
467/// math below. Wraps the shared helper so DataFusion gets a `Plan` error
468/// on non-timestamp columns.
469fn to_millis_array(array: &ArrayRef) -> Result<TimestampMillisecondArray> {
470    cast_to_millis_array(array.as_ref()).map_err(|e| DataFusionError::Plan(e.to_string()))
471}
472
473/// Computes tumble window start for an array of timestamps.
474fn compute_tumble_array(array: &ArrayRef, interval_ms: i64) -> Result<ArrayRef> {
475    let input = to_millis_array(array)?;
476    let result: TimestampMillisecondArray = input
477        .iter()
478        .map(|opt_ts| opt_ts.map(|ts| ts - ts.rem_euclid(interval_ms)))
479        .collect();
480    Ok(Arc::new(result))
481}
482
483/// Computes tumble window start with offset for a `ColumnarValue`.
484fn compute_tumble_with_offset(
485    value: &ColumnarValue,
486    interval_ms: i64,
487    offset_ms: i64,
488) -> Result<ColumnarValue> {
489    if offset_ms == 0 {
490        return compute_tumble(value, interval_ms);
491    }
492    match value {
493        ColumnarValue::Array(array) => {
494            let result = compute_tumble_array_with_offset(array, interval_ms, offset_ms)?;
495            Ok(ColumnarValue::Array(result))
496        }
497        ColumnarValue::Scalar(scalar) => {
498            let ts_ms = scalar_to_timestamp_ms(scalar)?;
499            let window_start = ts_ms.map(|ts| {
500                let adj = ts - offset_ms;
501                (adj - adj.rem_euclid(interval_ms)) + offset_ms
502            });
503            Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
504                window_start,
505                None,
506            )))
507        }
508    }
509}
510
511/// Computes tumble window start with offset for an array of timestamps.
512fn compute_tumble_array_with_offset(
513    array: &ArrayRef,
514    interval_ms: i64,
515    offset_ms: i64,
516) -> Result<ArrayRef> {
517    let input = to_millis_array(array)?;
518    let result: TimestampMillisecondArray = input
519        .iter()
520        .map(|opt_ts| {
521            opt_ts.map(|ts| {
522                let adj = ts - offset_ms;
523                (adj - adj.rem_euclid(interval_ms)) + offset_ms
524            })
525        })
526        .collect();
527    Ok(Arc::new(result))
528}
529
530/// Computes hop (earliest) window start for a `ColumnarValue`.
531fn compute_hop(value: &ColumnarValue, slide_ms: i64, size_ms: i64) -> Result<ColumnarValue> {
532    match value {
533        ColumnarValue::Array(array) => {
534            let result = compute_hop_array(array, slide_ms, size_ms)?;
535            Ok(ColumnarValue::Array(result))
536        }
537        ColumnarValue::Scalar(scalar) => {
538            let ts_ms = scalar_to_timestamp_ms(scalar)?;
539            let window_start = ts_ms.map(|ts| hop_earliest_start(ts, slide_ms, size_ms));
540            Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
541                window_start,
542                None,
543            )))
544        }
545    }
546}
547
548/// Computes hop window start for an array of timestamps.
549fn compute_hop_array(array: &ArrayRef, slide_ms: i64, size_ms: i64) -> Result<ArrayRef> {
550    let input = to_millis_array(array)?;
551    let result: TimestampMillisecondArray = input
552        .iter()
553        .map(|opt_ts| opt_ts.map(|ts| hop_earliest_start(ts, slide_ms, size_ms)))
554        .collect();
555    Ok(Arc::new(result))
556}
557
/// Returns the start of the earliest hopping window that contains `ts`.
///
/// Window starts sit on multiples of `slide_ms`, and each window spans
/// `[start, start + size_ms)`. The earliest window containing `ts` starts at
/// `ceil((ts - size + 1) / slide) * slide`. That equals
/// `floor((ts - size + slide) / slide) * slide`, which is computed here with
/// Euclidean division so negative timestamps floor correctly.
///
/// Only meaningful when `slide_ms <= size_ms`. With a larger slide, `ts` may
/// fall in a gap between windows and the returned start can exceed `ts`.
#[inline]
fn hop_earliest_start(ts: i64, slide_ms: i64, size_ms: i64) -> i64 {
    let shifted = ts - size_ms + slide_ms;
    shifted.div_euclid(slide_ms) * slide_ms
}
567
568/// Computes hop (earliest) window start with offset.
569fn compute_hop_with_offset(
570    value: &ColumnarValue,
571    slide_ms: i64,
572    size_ms: i64,
573    offset_ms: i64,
574) -> Result<ColumnarValue> {
575    if offset_ms == 0 {
576        return compute_hop(value, slide_ms, size_ms);
577    }
578    match value {
579        ColumnarValue::Array(array) => {
580            let result = compute_hop_array_with_offset(array, slide_ms, size_ms, offset_ms)?;
581            Ok(ColumnarValue::Array(result))
582        }
583        ColumnarValue::Scalar(scalar) => {
584            let ts_ms = scalar_to_timestamp_ms(scalar)?;
585            let window_start =
586                ts_ms.map(|ts| hop_earliest_start(ts - offset_ms, slide_ms, size_ms) + offset_ms);
587            Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
588                window_start,
589                None,
590            )))
591        }
592    }
593}
594
595/// Computes hop window start with offset for an array.
596fn compute_hop_array_with_offset(
597    array: &ArrayRef,
598    slide_ms: i64,
599    size_ms: i64,
600    offset_ms: i64,
601) -> Result<ArrayRef> {
602    let input = to_millis_array(array)?;
603    let result: TimestampMillisecondArray = input
604        .iter()
605        .map(|opt_ts| {
606            opt_ts.map(|ts| hop_earliest_start(ts - offset_ms, slide_ms, size_ms) + offset_ms)
607        })
608        .collect();
609    Ok(Arc::new(result))
610}
611
612/// Converts a timestamp array to `TimestampMillisecond` for consistent output.
613fn convert_to_timestamp_ms_array(array: &ArrayRef) -> Result<ArrayRef> {
614    let ms = to_millis_array(array)?;
615    Ok(Arc::new(ms))
616}
617
#[cfg(test)]
mod tests {
    // Unit tests for the window UDFs. Each test drives `invoke_with_args`
    // directly with hand-built `ScalarFunctionArgs` (see `make_args`).

    use super::*;
    use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano, TimestampMillisecondType};
    use arrow_array::cast::AsArray;
    use arrow_array::Array;
    use arrow_schema::Field;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ScalarUDF;

    /// Scalar `IntervalDayTime` argument of `days` days plus `ms` milliseconds.
    fn interval_dt(days: i32, ms: i32) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new(
            days, ms,
        ))))
    }

    /// Scalar `TimestampMillisecond` argument without a time zone.
    fn ts_ms(ms: Option<i64>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(ms, None))
    }

    /// Unwraps a `TimestampMillisecond` scalar result, panicking otherwise.
    fn expect_ts_ms(result: ColumnarValue) -> Option<i64> {
        match result {
            ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, _)) => v,
            other => panic!("Expected TimestampMillisecond scalar, got: {other:?}"),
        }
    }

    /// Wraps `args` in `ScalarFunctionArgs` with a millisecond-timestamp
    /// return field and the default config.
    fn make_args(args: Vec<ColumnarValue>, rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows: rows,
            return_field: Arc::new(Field::new(
                "output",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                true,
            )),
            config_options: Arc::new(ConfigOptions::default()),
        }
    }

    // ── Tumble tests ─────────────────────────────────────────────────────

    #[test]
    fn test_tumble_basic() {
        let udf = TumbleWindowStart::new();
        // 5-minute interval = 300_000 ms, timestamp at 7 min
        let result = udf
            .invoke_with_args(make_args(
                vec![ts_ms(Some(420_000)), interval_dt(0, 300_000)],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(300_000));
    }

    #[test]
    fn test_tumble_exact_boundary() {
        let udf = TumbleWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(
                vec![ts_ms(Some(300_000)), interval_dt(0, 300_000)],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(300_000));
    }

    #[test]
    fn test_tumble_zero_timestamp() {
        let udf = TumbleWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(Some(0)), interval_dt(0, 300_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(0));
    }

    #[test]
    fn test_tumble_null_handling() {
        let udf = TumbleWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(None), interval_dt(0, 300_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), None);
    }

    #[test]
    fn test_tumble_array_input() {
        let udf = TumbleWindowStart::new();
        let ts_array = TimestampMillisecondArray::from(vec![
            Some(0),
            Some(150_000),
            Some(300_000),
            Some(420_000),
            None,
        ]);
        let ts = ColumnarValue::Array(Arc::new(ts_array));
        let interval = interval_dt(0, 300_000);

        let result = udf
            .invoke_with_args(make_args(vec![ts, interval], 5))
            .unwrap();
        match result {
            ColumnarValue::Array(arr) => {
                let r = arr.as_primitive::<TimestampMillisecondType>();
                assert_eq!(r.value(0), 0);
                assert_eq!(r.value(1), 0);
                assert_eq!(r.value(2), 300_000);
                assert_eq!(r.value(3), 300_000);
                assert!(r.is_null(4));
            }
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    /// Regression: TUMBLE over a `Timestamp(Nanosecond)` column used to
    /// bail out with "Unsupported timestamp type for tumble()" because
    /// the array fast path only handled `Timestamp(Millisecond)` and
    /// `Int64`. `to_millis_array` now casts any Timestamp precision.
    #[test]
    fn test_tumble_array_input_nanosecond() {
        use arrow_array::TimestampNanosecondArray;

        let udf = TumbleWindowStart::new();
        // 0s, 150s, 300s, 420s in ns.
        let ts_array = TimestampNanosecondArray::from(vec![
            Some(0),
            Some(150_000_000_000),
            Some(300_000_000_000),
            Some(420_000_000_000),
            None,
        ]);
        let ts = ColumnarValue::Array(Arc::new(ts_array));
        let interval = interval_dt(0, 300_000); // 5 minutes

        let result = udf
            .invoke_with_args(make_args(vec![ts, interval], 5))
            .unwrap();
        match result {
            ColumnarValue::Array(arr) => {
                let r = arr.as_primitive::<TimestampMillisecondType>();
                assert_eq!(r.value(0), 0);
                assert_eq!(r.value(1), 0);
                assert_eq!(r.value(2), 300_000);
                assert_eq!(r.value(3), 300_000);
                assert!(r.is_null(4));
            }
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    /// Regression: HOP over a `Timestamp(Nanosecond)` column — same
    /// missing arm as TUMBLE, fixed by the same `to_millis_array` helper.
    #[test]
    fn test_hop_array_input_nanosecond() {
        use arrow_array::TimestampNanosecondArray;

        let udf = HopWindowStart::new();
        // 7 minutes in ns.
        let ts_array = TimestampNanosecondArray::from(vec![Some(420_000_000_000)]);
        let ts = ColumnarValue::Array(Arc::new(ts_array));
        // slide=5min, size=10min
        let result = udf
            .invoke_with_args(make_args(
                vec![ts, interval_dt(0, 300_000), interval_dt(0, 600_000)],
                1,
            ))
            .unwrap();
        match result {
            ColumnarValue::Array(arr) => {
                let r = arr.as_primitive::<TimestampMillisecondType>();
                assert_eq!(r.value(0), 0);
            }
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_tumble_month_day_nano_interval() {
        let udf = TumbleWindowStart::new();
        // 1 hour = 3_600_000_000_000 nanoseconds
        let interval = ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(
            IntervalMonthDayNano::new(0, 0, 3_600_000_000_000),
        )));
        // 90 minutes = 5_400_000 ms
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(Some(5_400_000)), interval], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(3_600_000));
    }

    #[test]
    fn test_tumble_rejects_zero_interval() {
        let udf = TumbleWindowStart::new();
        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000)), interval_dt(0, 0)], 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_tumble_rejects_wrong_arg_count() {
        let udf = TumbleWindowStart::new();
        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000))], 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_tumble_with_offset() {
        let udf = TumbleWindowStart::new();
        // 5-minute windows shifted by 1 minute: [6min, 11min) contains 7min.
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(420_000)),
                    interval_dt(0, 300_000),
                    interval_dt(0, 60_000),
                ],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(360_000));
    }

    #[test]
    fn test_tumble_negative_timestamp_floors() {
        let udf = TumbleWindowStart::new();
        // 1 ms before the epoch belongs to the window [-5min, 0).
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(Some(-1)), interval_dt(0, 300_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(-300_000));
    }

    #[test]
    fn test_tumble_second_precision_scalar() {
        let udf = TumbleWindowStart::new();
        // 420 s = 7 min, which falls in the 5-minute window starting at 5 min.
        let ts = ColumnarValue::Scalar(ScalarValue::TimestampSecond(Some(420), None));
        let result = udf
            .invoke_with_args(make_args(vec![ts, interval_dt(0, 300_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(300_000));
    }

    #[test]
    fn test_tumble_rejects_null_interval() {
        let udf = TumbleWindowStart::new();
        let interval = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(None));
        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000)), interval], 1));
        assert!(result.is_err());
    }

    // ── Hop tests ────────────────────────────────────────────────────────

    #[test]
    fn test_hop_basic() {
        let udf = HopWindowStart::new();
        // slide=5min, size=10min, ts=7min
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(420_000)),
                    interval_dt(0, 300_000),
                    interval_dt(0, 600_000),
                ],
                1,
            ))
            .unwrap();
        // Earliest 10-min window (sliding 5min) containing 420_000:
        // adjusted = 420_000 - 600_000 + 300_000 = 120_000
        // 120_000 - (120_000 % 300_000) = 120_000 - 120_000 = 0
        assert_eq!(expect_ts_ms(result), Some(0));
    }

    #[test]
    fn test_hop_at_boundary() {
        let udf = HopWindowStart::new();
        // slide=5min, size=10min, ts=exactly 5min
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(300_000)),
                    interval_dt(0, 300_000),
                    interval_dt(0, 600_000),
                ],
                1,
            ))
            .unwrap();
        // adjusted = 300_000 - 600_000 + 300_000 = 0
        // 0 - (0 % 300_000) = 0
        assert_eq!(expect_ts_ms(result), Some(0));
    }

    #[test]
    fn test_hop_rejects_wrong_arg_count() {
        let udf = HopWindowStart::new();
        let result = udf.invoke_with_args(make_args(
            vec![ts_ms(Some(1000)), interval_dt(0, 300_000)],
            1,
        ));
        assert!(result.is_err());
    }

    #[test]
    fn test_hop_with_offset() {
        let udf = HopWindowStart::new();
        // slide=5min, size=10min, offset=1min, ts=7min.
        // Window starts are ..., -4min, 1min, 6min, ...; [1min, 11min) is the
        // earliest window containing 7min.
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(420_000)),
                    interval_dt(0, 300_000),
                    interval_dt(0, 600_000),
                    interval_dt(0, 60_000),
                ],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(60_000));
    }

    // ── Session tests ────────────────────────────────────────────────────

    #[test]
    fn test_session_passthrough_scalar() {
        let udf = SessionWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(
                vec![ts_ms(Some(42_000)), interval_dt(0, 60_000)],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(42_000));
    }

    #[test]
    fn test_session_passthrough_null() {
        let udf = SessionWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(None), interval_dt(0, 60_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), None);
    }

    #[test]
    fn test_session_passthrough_array() {
        let udf = SessionWindowStart::new();
        let ts_array = TimestampMillisecondArray::from(vec![Some(42_000), None]);
        let ts = ColumnarValue::Array(Arc::new(ts_array));
        let result = udf
            .invoke_with_args(make_args(vec![ts, interval_dt(0, 60_000)], 2))
            .unwrap();
        match result {
            ColumnarValue::Array(arr) => {
                let r = arr.as_primitive::<TimestampMillisecondType>();
                assert_eq!(r.value(0), 42_000);
                assert!(r.is_null(1));
            }
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    // ── Cumulate tests ──────────────────────────────────────────────────

    #[test]
    fn test_cumulate_basic() {
        let udf = CumulateWindowStart::new();
        // step=1min, size=5min, ts=30s → epoch start = 0
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(30_000)),
                    interval_dt(0, 60_000),
                    interval_dt(0, 300_000),
                ],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(0));
    }

    #[test]
    fn test_cumulate_second_epoch() {
        let udf = CumulateWindowStart::new();
        // step=1min, size=5min, ts=350s → epoch start = 300_000
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(350_000)),
                    interval_dt(0, 60_000),
                    interval_dt(0, 300_000),
                ],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(300_000));
    }

    #[test]
    fn test_cumulate_array_input() {
        let udf = CumulateWindowStart::new();
        // step=1min, size=5min; 30s → 0, 350s → 300_000, NULL stays NULL.
        let ts_array = TimestampMillisecondArray::from(vec![Some(30_000), Some(350_000), None]);
        let ts = ColumnarValue::Array(Arc::new(ts_array));
        let result = udf
            .invoke_with_args(make_args(
                vec![ts, interval_dt(0, 60_000), interval_dt(0, 300_000)],
                3,
            ))
            .unwrap();
        match result {
            ColumnarValue::Array(arr) => {
                let r = arr.as_primitive::<TimestampMillisecondType>();
                assert_eq!(r.value(0), 0);
                assert_eq!(r.value(1), 300_000);
                assert!(r.is_null(2));
            }
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_cumulate_rejects_step_exceeds_size() {
        let udf = CumulateWindowStart::new();
        let result = udf.invoke_with_args(make_args(
            vec![
                ts_ms(Some(1000)),
                interval_dt(0, 600_000),
                interval_dt(0, 300_000),
            ],
            1,
        ));
        assert!(result.is_err());
    }

    #[test]
    fn test_cumulate_rejects_not_divisible() {
        let udf = CumulateWindowStart::new();
        let result = udf.invoke_with_args(make_args(
            vec![
                ts_ms(Some(1000)),
                interval_dt(0, 70_000),
                interval_dt(0, 300_000),
            ],
            1,
        ));
        assert!(result.is_err());
    }

    #[test]
    fn test_cumulate_rejects_wrong_arg_count() {
        let udf = CumulateWindowStart::new();
        let result = udf.invoke_with_args(make_args(
            vec![ts_ms(Some(1000)), interval_dt(0, 60_000)],
            1,
        ));
        assert!(result.is_err());
    }

    // ── Registration & signature tests ───────────────────────────────────

    #[test]
    fn test_udf_registration() {
        let tumble = ScalarUDF::new_from_impl(TumbleWindowStart::new());
        assert_eq!(tumble.name(), "tumble");

        let hop = ScalarUDF::new_from_impl(HopWindowStart::new());
        assert_eq!(hop.name(), "hop");

        let session = ScalarUDF::new_from_impl(SessionWindowStart::new());
        assert_eq!(session.name(), "session");

        let cumulate = ScalarUDF::new_from_impl(CumulateWindowStart::new());
        assert_eq!(cumulate.name(), "cumulate");
    }

    #[test]
    fn test_udf_signatures_immutable() {
        assert_eq!(
            TumbleWindowStart::new().signature().volatility,
            Volatility::Immutable
        );
        assert_eq!(
            HopWindowStart::new().signature().volatility,
            Volatility::Immutable
        );
        assert_eq!(
            SessionWindowStart::new().signature().volatility,
            Volatility::Immutable
        );
        assert_eq!(
            CumulateWindowStart::new().signature().volatility,
            Volatility::Immutable
        );
    }

    #[test]
    fn test_tumble_return_type() {
        let udf = TumbleWindowStart::new();
        let rt = udf.return_type(&[]).unwrap();
        assert_eq!(rt, DataType::Timestamp(TimeUnit::Millisecond, None));
    }

    #[test]
    fn test_cumulate_return_type() {
        let udf = CumulateWindowStart::new();
        let rt = udf.return_type(&[]).unwrap();
        assert_eq!(rt, DataType::Timestamp(TimeUnit::Millisecond, None));
    }
}