Skip to main content

laminar_sql/datafusion/
window_udf.rs

//! Window function UDFs for `DataFusion` integration
//!
//! Provides scalar UDFs that compute window start timestamps for
//! streaming window operations:
//!
//! - [`TumbleWindowStart`] — `tumble(timestamp, interval)` — fixed-size non-overlapping windows
//! - [`HopWindowStart`] — `hop(timestamp, slide, size)` — fixed-size overlapping windows
//! - [`SessionWindowStart`] — `session(timestamp, gap)` — pass-through for Ring 0 sessions
//! - [`CumulateWindowStart`] — `cumulate(timestamp, step, size)` — epoch start of cumulating windows
//!
//! These UDFs allow `DataFusion` to execute `GROUP BY TUMBLE(...)` style queries
//! by computing the window start as a per-row scalar value.
12
use std::any::Any;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use arrow::datatypes::{
    DataType, Int64Type, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
    TimestampNanosecondType, TimestampSecondType,
};
use arrow_array::cast::AsArray;
use arrow_array::{ArrayRef, TimestampMillisecondArray};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
24
25// ─── TumbleWindowStart ───────────────────────────────────────────────────────
26
27/// Computes the tumbling window start for a given timestamp.
28///
29/// `tumble(timestamp, interval)` returns `floor(ts / interval) * interval`,
30/// which is the start of the non-overlapping window that contains `ts`.
31///
32/// # Arguments
33///
34/// * Arg 0: Timestamp column or scalar (`TimestampMillisecond` or `Int64` ms)
35/// * Arg 1: Window size as an interval scalar
36///
37/// # Returns
38///
39/// `TimestampMillisecond` representing the window start.
40#[derive(Debug)]
41pub struct TumbleWindowStart {
42    signature: Signature,
43}
44
45impl TumbleWindowStart {
46    /// Creates a new tumble window start UDF.
47    #[must_use]
48    pub fn new() -> Self {
49        Self {
50            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
51        }
52    }
53}
54
55impl Default for TumbleWindowStart {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl PartialEq for TumbleWindowStart {
62    fn eq(&self, _other: &Self) -> bool {
63        true // All instances are identical
64    }
65}
66
67impl Eq for TumbleWindowStart {}
68
69impl Hash for TumbleWindowStart {
70    fn hash<H: Hasher>(&self, state: &mut H) {
71        "tumble".hash(state);
72    }
73}
74
75impl ScalarUDFImpl for TumbleWindowStart {
76    fn as_any(&self) -> &dyn Any {
77        self
78    }
79
80    fn name(&self) -> &'static str {
81        "tumble"
82    }
83
84    fn signature(&self) -> &Signature {
85        &self.signature
86    }
87
88    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
89        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
90    }
91
92    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
93        let ScalarFunctionArgs { args, .. } = args;
94        if args.len() != 2 {
95            return Err(DataFusionError::Plan(
96                "tumble() requires exactly 2 arguments: (timestamp, interval)".to_string(),
97            ));
98        }
99        let interval_ms = extract_interval_ms(&args[1])?;
100        if interval_ms <= 0 {
101            return Err(DataFusionError::Plan(
102                "tumble() interval must be positive".to_string(),
103            ));
104        }
105        compute_tumble(&args[0], interval_ms)
106    }
107}
108
109// ─── HopWindowStart ──────────────────────────────────────────────────────────
110
111/// Computes the earliest hopping window start for a given timestamp.
112///
113/// `hop(timestamp, slide, size)` returns the start of the earliest window
114/// (of the given `size`, sliding by `slide`) that contains `ts`.
115///
116/// # Limitation
117///
118/// This returns only the *earliest* window start. Full multi-window
119/// assignment (one row per window) is handled by Ring 0 operators.
120///
121/// # Arguments
122///
123/// * Arg 0: Timestamp column or scalar
124/// * Arg 1: Slide interval scalar
125/// * Arg 2: Window size interval scalar
126#[derive(Debug)]
127pub struct HopWindowStart {
128    signature: Signature,
129}
130
131impl HopWindowStart {
132    /// Creates a new hop window start UDF.
133    #[must_use]
134    pub fn new() -> Self {
135        Self {
136            signature: Signature::new(TypeSignature::Any(3), Volatility::Immutable),
137        }
138    }
139}
140
141impl Default for HopWindowStart {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147impl PartialEq for HopWindowStart {
148    fn eq(&self, _other: &Self) -> bool {
149        true
150    }
151}
152
153impl Eq for HopWindowStart {}
154
155impl Hash for HopWindowStart {
156    fn hash<H: Hasher>(&self, state: &mut H) {
157        "hop".hash(state);
158    }
159}
160
161impl ScalarUDFImpl for HopWindowStart {
162    fn as_any(&self) -> &dyn Any {
163        self
164    }
165
166    fn name(&self) -> &'static str {
167        "hop"
168    }
169
170    fn signature(&self) -> &Signature {
171        &self.signature
172    }
173
174    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
175        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
176    }
177
178    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
179        let ScalarFunctionArgs { args, .. } = args;
180        if args.len() != 3 {
181            return Err(DataFusionError::Plan(
182                "hop() requires exactly 3 arguments: (timestamp, slide, size)".to_string(),
183            ));
184        }
185        let slide_ms = extract_interval_ms(&args[1])?;
186        let size_ms = extract_interval_ms(&args[2])?;
187        if slide_ms <= 0 || size_ms <= 0 {
188            return Err(DataFusionError::Plan(
189                "hop() slide and size must be positive".to_string(),
190            ));
191        }
192        compute_hop(&args[0], slide_ms, size_ms)
193    }
194}
195
196// ─── SessionWindowStart ──────────────────────────────────────────────────────
197
198/// Pass-through UDF for session window compatibility.
199///
200/// `session(timestamp, gap)` returns the input timestamp unchanged.
201/// Session windows are data-dependent (gap-based grouping) and cannot
202/// be computed as a per-row scalar. The actual session assignment is
203/// handled by Ring 0 operators.
204///
205/// This UDF exists so that `GROUP BY SESSION(ts, gap)` is syntactically
206/// valid in `DataFusion` queries, with real session logic deferred to
207/// the streaming engine.
208#[derive(Debug)]
209pub struct SessionWindowStart {
210    signature: Signature,
211}
212
213impl SessionWindowStart {
214    /// Creates a new session window start UDF.
215    #[must_use]
216    pub fn new() -> Self {
217        Self {
218            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
219        }
220    }
221}
222
223impl Default for SessionWindowStart {
224    fn default() -> Self {
225        Self::new()
226    }
227}
228
229impl PartialEq for SessionWindowStart {
230    fn eq(&self, _other: &Self) -> bool {
231        true
232    }
233}
234
235impl Eq for SessionWindowStart {}
236
237impl Hash for SessionWindowStart {
238    fn hash<H: Hasher>(&self, state: &mut H) {
239        "session".hash(state);
240    }
241}
242
243impl ScalarUDFImpl for SessionWindowStart {
244    fn as_any(&self) -> &dyn Any {
245        self
246    }
247
248    fn name(&self) -> &'static str {
249        "session"
250    }
251
252    fn signature(&self) -> &Signature {
253        &self.signature
254    }
255
256    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
257        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
258    }
259
260    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
261        let ScalarFunctionArgs { args, .. } = args;
262        if args.len() != 2 {
263            return Err(DataFusionError::Plan(
264                "session() requires exactly 2 arguments: (timestamp, gap)".to_string(),
265            ));
266        }
267        // Pass-through: return the input timestamp as-is
268        match &args[0] {
269            ColumnarValue::Array(array) => {
270                let result = convert_to_timestamp_ms_array(array)?;
271                Ok(ColumnarValue::Array(result))
272            }
273            ColumnarValue::Scalar(scalar) => {
274                let ts_ms = scalar_to_timestamp_ms(scalar)?;
275                Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
276                    ts_ms, None,
277                )))
278            }
279        }
280    }
281}
282
283// ─── CumulateWindowStart ────────────────────────────────────────────────────
284
285/// Computes the cumulate window epoch start for a given timestamp.
286///
287/// `cumulate(timestamp, step, size)` returns `floor(ts / size) * size`,
288/// which is the epoch start for the cumulating window that contains `ts`.
289/// The actual multi-window assignment (one row per cumulating window) is
290/// handled by Ring 0 operators.
291///
292/// # Arguments
293///
294/// * Arg 0: Timestamp column or scalar (`TimestampMillisecond` or `Int64` ms)
295/// * Arg 1: Step interval scalar (window growth increment)
296/// * Arg 2: Max size interval scalar (epoch size)
297#[derive(Debug)]
298pub struct CumulateWindowStart {
299    signature: Signature,
300}
301
302impl CumulateWindowStart {
303    /// Creates a new cumulate window start UDF.
304    #[must_use]
305    pub fn new() -> Self {
306        Self {
307            signature: Signature::new(TypeSignature::Any(3), Volatility::Immutable),
308        }
309    }
310}
311
312impl Default for CumulateWindowStart {
313    fn default() -> Self {
314        Self::new()
315    }
316}
317
318impl PartialEq for CumulateWindowStart {
319    fn eq(&self, _other: &Self) -> bool {
320        true
321    }
322}
323
324impl Eq for CumulateWindowStart {}
325
326impl Hash for CumulateWindowStart {
327    fn hash<H: Hasher>(&self, state: &mut H) {
328        "cumulate".hash(state);
329    }
330}
331
332impl ScalarUDFImpl for CumulateWindowStart {
333    fn as_any(&self) -> &dyn Any {
334        self
335    }
336
337    fn name(&self) -> &'static str {
338        "cumulate"
339    }
340
341    fn signature(&self) -> &Signature {
342        &self.signature
343    }
344
345    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
346        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
347    }
348
349    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
350        let ScalarFunctionArgs { args, .. } = args;
351        if args.len() != 3 {
352            return Err(DataFusionError::Plan(
353                "cumulate() requires exactly 3 arguments: (timestamp, step, size)".to_string(),
354            ));
355        }
356        let step_ms = extract_interval_ms(&args[1])?;
357        let size_ms = extract_interval_ms(&args[2])?;
358        if step_ms <= 0 || size_ms <= 0 {
359            return Err(DataFusionError::Plan(
360                "cumulate() step and size must be positive".to_string(),
361            ));
362        }
363        if step_ms > size_ms {
364            return Err(DataFusionError::Plan(
365                "cumulate() step must not exceed size".to_string(),
366            ));
367        }
368        if size_ms % step_ms != 0 {
369            return Err(DataFusionError::Plan(
370                "cumulate() size must be evenly divisible by step".to_string(),
371            ));
372        }
373        // Return epoch start = floor(ts / size) * size (same as tumble with size)
374        compute_tumble(&args[0], size_ms)
375    }
376}
377
378// ─── Helper Functions ────────────────────────────────────────────────────────
379
380/// Extracts an interval value in milliseconds from a `ColumnarValue`.
381///
382/// Only scalar intervals are supported (array intervals would require
383/// per-row window sizes, which is not a valid streaming pattern).
384fn extract_interval_ms(value: &ColumnarValue) -> Result<i64> {
385    match value {
386        ColumnarValue::Scalar(scalar) => scalar_interval_to_ms(scalar),
387        ColumnarValue::Array(_) => Err(DataFusionError::NotImplemented(
388            "Array interval arguments not supported for window functions".to_string(),
389        )),
390    }
391}
392
393/// Converts a scalar interval to milliseconds.
394fn scalar_interval_to_ms(scalar: &ScalarValue) -> Result<i64> {
395    match scalar {
396        ScalarValue::IntervalDayTime(Some(v)) => {
397            Ok(i64::from(v.days) * 86_400_000 + i64::from(v.milliseconds))
398        }
399        ScalarValue::IntervalMonthDayNano(Some(v)) => {
400            if v.months != 0 {
401                return Err(DataFusionError::NotImplemented(
402                    "Month-based intervals not supported for window functions \
403                     (use days/hours/minutes/seconds)"
404                        .to_string(),
405                ));
406            }
407            Ok(i64::from(v.days) * 86_400_000 + v.nanoseconds / 1_000_000)
408        }
409        ScalarValue::IntervalYearMonth(_) => Err(DataFusionError::NotImplemented(
410            "Year-month intervals not supported for window functions".to_string(),
411        )),
412        ScalarValue::Int64(Some(ms)) => Ok(*ms),
413        _ => Err(DataFusionError::Plan(format!(
414            "Expected interval argument for window function, got: {scalar:?}"
415        ))),
416    }
417}
418
419/// Converts a scalar value to a timestamp in milliseconds.
420fn scalar_to_timestamp_ms(scalar: &ScalarValue) -> Result<Option<i64>> {
421    match scalar {
422        ScalarValue::TimestampMillisecond(v, _) | ScalarValue::Int64(v) => Ok(*v),
423        ScalarValue::TimestampMicrosecond(v, _) => Ok(v.map(|v| v / 1_000)),
424        ScalarValue::TimestampNanosecond(v, _) => Ok(v.map(|v| v / 1_000_000)),
425        ScalarValue::TimestampSecond(v, _) => Ok(v.map(|v| v * 1_000)),
426        _ => Err(DataFusionError::Plan(format!(
427            "Expected timestamp argument for window function, got: {scalar:?}"
428        ))),
429    }
430}
431
432/// Computes tumble window start for a `ColumnarValue`.
433fn compute_tumble(value: &ColumnarValue, interval_ms: i64) -> Result<ColumnarValue> {
434    match value {
435        ColumnarValue::Array(array) => {
436            let result = compute_tumble_array(array, interval_ms)?;
437            Ok(ColumnarValue::Array(result))
438        }
439        ColumnarValue::Scalar(scalar) => {
440            let ts_ms = scalar_to_timestamp_ms(scalar)?;
441            let window_start = ts_ms.map(|ts| ts - ts.rem_euclid(interval_ms));
442            Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
443                window_start,
444                None,
445            )))
446        }
447    }
448}
449
450/// Computes tumble window start for an array of timestamps.
451fn compute_tumble_array(array: &ArrayRef, interval_ms: i64) -> Result<ArrayRef> {
452    match array.data_type() {
453        DataType::Timestamp(TimeUnit::Millisecond, _) => {
454            let input = array.as_primitive::<TimestampMillisecondType>();
455            let result: TimestampMillisecondArray = input
456                .iter()
457                .map(|opt_ts| opt_ts.map(|ts| ts - ts.rem_euclid(interval_ms)))
458                .collect();
459            Ok(Arc::new(result))
460        }
461        DataType::Int64 => {
462            let input = array.as_primitive::<Int64Type>();
463            let result: TimestampMillisecondArray = input
464                .iter()
465                .map(|opt_ts| opt_ts.map(|ts| ts - ts.rem_euclid(interval_ms)))
466                .collect();
467            Ok(Arc::new(result))
468        }
469        other => Err(DataFusionError::Plan(format!(
470            "Unsupported timestamp type for tumble(): {other:?}. \
471             Use TimestampMillisecond or Int64."
472        ))),
473    }
474}
475
476/// Computes hop (earliest) window start for a `ColumnarValue`.
477fn compute_hop(value: &ColumnarValue, slide_ms: i64, size_ms: i64) -> Result<ColumnarValue> {
478    match value {
479        ColumnarValue::Array(array) => {
480            let result = compute_hop_array(array, slide_ms, size_ms)?;
481            Ok(ColumnarValue::Array(result))
482        }
483        ColumnarValue::Scalar(scalar) => {
484            let ts_ms = scalar_to_timestamp_ms(scalar)?;
485            let window_start = ts_ms.map(|ts| hop_earliest_start(ts, slide_ms, size_ms));
486            Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
487                window_start,
488                None,
489            )))
490        }
491    }
492}
493
494/// Computes hop window start for an array of timestamps.
495fn compute_hop_array(array: &ArrayRef, slide_ms: i64, size_ms: i64) -> Result<ArrayRef> {
496    match array.data_type() {
497        DataType::Timestamp(TimeUnit::Millisecond, _) => {
498            let input = array.as_primitive::<TimestampMillisecondType>();
499            let result: TimestampMillisecondArray = input
500                .iter()
501                .map(|opt_ts| opt_ts.map(|ts| hop_earliest_start(ts, slide_ms, size_ms)))
502                .collect();
503            Ok(Arc::new(result))
504        }
505        DataType::Int64 => {
506            let input = array.as_primitive::<Int64Type>();
507            let result: TimestampMillisecondArray = input
508                .iter()
509                .map(|opt_ts| opt_ts.map(|ts| hop_earliest_start(ts, slide_ms, size_ms)))
510                .collect();
511            Ok(Arc::new(result))
512        }
513        other => Err(DataFusionError::Plan(format!(
514            "Unsupported timestamp type for hop(): {other:?}. \
515             Use TimestampMillisecond or Int64."
516        ))),
517    }
518}
519
520/// Computes the earliest window start for a hopping window containing `ts`.
521///
522/// Windows of `size_ms` slide by `slide_ms`. The earliest window that
523/// contains `ts` starts at `floor((ts - size + slide) / slide) * slide`.
524#[inline]
525fn hop_earliest_start(ts: i64, slide_ms: i64, size_ms: i64) -> i64 {
526    let adjusted = ts - size_ms + slide_ms;
527    adjusted - adjusted.rem_euclid(slide_ms)
528}
529
530/// Converts a timestamp array to `TimestampMillisecond` for consistent output.
531fn convert_to_timestamp_ms_array(array: &ArrayRef) -> Result<ArrayRef> {
532    match array.data_type() {
533        DataType::Timestamp(TimeUnit::Millisecond, _) => Ok(Arc::clone(array)),
534        DataType::Int64 => {
535            let input = array.as_primitive::<Int64Type>();
536            let result: TimestampMillisecondArray = input.iter().collect();
537            Ok(Arc::new(result))
538        }
539        other => Err(DataFusionError::Plan(format!(
540            "Unsupported timestamp type for session(): {other:?}. \
541             Use TimestampMillisecond or Int64."
542        ))),
543    }
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};
    use arrow_array::{Array, Int64Array};
    use arrow_schema::Field;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ScalarUDF;

    fn interval_dt(days: i32, ms: i32) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new(
            days, ms,
        ))))
    }

    fn ts_ms(ms: Option<i64>) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(ms, None))
    }

    fn expect_ts_ms(result: ColumnarValue) -> Option<i64> {
        match result {
            ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, _)) => v,
            other => panic!("Expected TimestampMillisecond scalar, got: {other:?}"),
        }
    }

    /// Unwraps an array result, asserts it has the UDFs' declared return type,
    /// and returns its values (NULL as `None`).
    fn expect_ts_ms_array(result: ColumnarValue) -> Vec<Option<i64>> {
        match result {
            ColumnarValue::Array(arr) => {
                assert_eq!(
                    arr.data_type(),
                    &DataType::Timestamp(TimeUnit::Millisecond, None)
                );
                arr.as_primitive::<TimestampMillisecondType>().iter().collect()
            }
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    fn make_args(args: Vec<ColumnarValue>, rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows: rows,
            return_field: Arc::new(Field::new(
                "output",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                true,
            )),
            config_options: Arc::new(ConfigOptions::default()),
        }
    }

    // ── Tumble tests ─────────────────────────────────────────────────────

    #[test]
    fn test_tumble_basic() {
        let udf = TumbleWindowStart::new();
        // 5-minute interval = 300_000 ms, timestamp at 7 min
        let result = udf
            .invoke_with_args(make_args(
                vec![ts_ms(Some(420_000)), interval_dt(0, 300_000)],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(300_000));
    }

    #[test]
    fn test_tumble_exact_boundary() {
        let udf = TumbleWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(
                vec![ts_ms(Some(300_000)), interval_dt(0, 300_000)],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(300_000));
    }

    #[test]
    fn test_tumble_zero_timestamp() {
        let udf = TumbleWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(Some(0)), interval_dt(0, 300_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(0));
    }

    #[test]
    fn test_tumble_negative_timestamp_floors() {
        let udf = TumbleWindowStart::new();
        // One millisecond before the epoch belongs to the window [-300_000, 0).
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(Some(-1)), interval_dt(0, 300_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(-300_000));
    }

    #[test]
    fn test_tumble_null_handling() {
        let udf = TumbleWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(None), interval_dt(0, 300_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), None);
    }

    #[test]
    fn test_tumble_array_input() {
        let udf = TumbleWindowStart::new();
        let ts_array = TimestampMillisecondArray::from(vec![
            Some(0),
            Some(150_000),
            Some(300_000),
            Some(420_000),
            None,
        ]);
        let ts = ColumnarValue::Array(Arc::new(ts_array));
        let interval = interval_dt(0, 300_000);

        let result = udf
            .invoke_with_args(make_args(vec![ts, interval], 5))
            .unwrap();
        match result {
            ColumnarValue::Array(arr) => {
                let r = arr.as_primitive::<TimestampMillisecondType>();
                assert_eq!(r.value(0), 0);
                assert_eq!(r.value(1), 0);
                assert_eq!(r.value(2), 300_000);
                assert_eq!(r.value(3), 300_000);
                assert!(r.is_null(4));
            }
            ColumnarValue::Scalar(_) => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_tumble_int64_array_input() {
        let udf = TumbleWindowStart::new();
        let ts = ColumnarValue::Array(Arc::new(Int64Array::from(vec![
            Some(0),
            Some(420_000),
            None,
        ])));
        let result = udf
            .invoke_with_args(make_args(vec![ts, interval_dt(0, 300_000)], 3))
            .unwrap();
        assert_eq!(expect_ts_ms_array(result), vec![Some(0), Some(300_000), None]);
    }

    #[test]
    fn test_tumble_int64_interval_is_milliseconds() {
        let udf = TumbleWindowStart::new();
        let interval = ColumnarValue::Scalar(ScalarValue::Int64(Some(60_000)));
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(Some(90_000)), interval], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(60_000));
    }

    #[test]
    fn test_tumble_month_day_nano_interval() {
        let udf = TumbleWindowStart::new();
        // 1 hour = 3_600_000_000_000 nanoseconds
        let interval = ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(
            IntervalMonthDayNano::new(0, 0, 3_600_000_000_000),
        )));
        // 90 minutes = 5_400_000 ms
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(Some(5_400_000)), interval], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(3_600_000));
    }

    #[test]
    fn test_tumble_rejects_month_interval() {
        let udf = TumbleWindowStart::new();
        let interval = ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(
            IntervalMonthDayNano::new(1, 0, 0),
        )));
        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000)), interval], 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_tumble_rejects_array_interval() {
        let udf = TumbleWindowStart::new();
        let interval = ColumnarValue::Array(Arc::new(Int64Array::from(vec![300_000_i64])));
        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000)), interval], 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_tumble_rejects_zero_interval() {
        let udf = TumbleWindowStart::new();
        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000)), interval_dt(0, 0)], 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_tumble_rejects_wrong_arg_count() {
        let udf = TumbleWindowStart::new();
        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000))], 1));
        assert!(result.is_err());
    }

    // ── Hop tests ────────────────────────────────────────────────────────

    #[test]
    fn test_hop_basic() {
        let udf = HopWindowStart::new();
        // slide=5min, size=10min, ts=7min
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(420_000)),
                    interval_dt(0, 300_000),
                    interval_dt(0, 600_000),
                ],
                1,
            ))
            .unwrap();
        // Earliest 10-min window (sliding 5min) containing 420_000:
        // adjusted = 420_000 - 600_000 + 300_000 = 120_000
        // 120_000 - (120_000 % 300_000) = 120_000 - 120_000 = 0
        assert_eq!(expect_ts_ms(result), Some(0));
    }

    #[test]
    fn test_hop_at_boundary() {
        let udf = HopWindowStart::new();
        // slide=5min, size=10min, ts=exactly 5min
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(300_000)),
                    interval_dt(0, 300_000),
                    interval_dt(0, 600_000),
                ],
                1,
            ))
            .unwrap();
        // adjusted = 300_000 - 600_000 + 300_000 = 0
        // 0 - (0 % 300_000) = 0
        assert_eq!(expect_ts_ms(result), Some(0));
    }

    #[test]
    fn test_hop_size_not_multiple_of_slide() {
        let udf = HopWindowStart::new();
        // slide=3ms, size=5ms: windows [0,5), [3,8), [6,11); earliest holding 7 is [3,8)
        let result = udf
            .invoke_with_args(make_args(
                vec![ts_ms(Some(7)), interval_dt(0, 3), interval_dt(0, 5)],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(3));
    }

    #[test]
    fn test_hop_array_input() {
        let udf = HopWindowStart::new();
        let ts = ColumnarValue::Array(Arc::new(TimestampMillisecondArray::from(vec![
            Some(420_000),
            Some(600_000),
            None,
        ])));
        let result = udf
            .invoke_with_args(make_args(
                vec![ts, interval_dt(0, 300_000), interval_dt(0, 600_000)],
                3,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms_array(result), vec![Some(0), Some(300_000), None]);
    }

    #[test]
    fn test_hop_rejects_non_positive_slide() {
        let udf = HopWindowStart::new();
        let result = udf.invoke_with_args(make_args(
            vec![ts_ms(Some(1000)), interval_dt(0, 0), interval_dt(0, 600_000)],
            1,
        ));
        assert!(result.is_err());
    }

    #[test]
    fn test_hop_rejects_wrong_arg_count() {
        let udf = HopWindowStart::new();
        let result = udf.invoke_with_args(make_args(
            vec![ts_ms(Some(1000)), interval_dt(0, 300_000)],
            1,
        ));
        assert!(result.is_err());
    }

    // ── Session tests ────────────────────────────────────────────────────

    #[test]
    fn test_session_passthrough_scalar() {
        let udf = SessionWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(
                vec![ts_ms(Some(42_000)), interval_dt(0, 60_000)],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(42_000));
    }

    #[test]
    fn test_session_passthrough_null() {
        let udf = SessionWindowStart::new();
        let result = udf
            .invoke_with_args(make_args(vec![ts_ms(None), interval_dt(0, 60_000)], 1))
            .unwrap();
        assert_eq!(expect_ts_ms(result), None);
    }

    #[test]
    fn test_session_passthrough_array() {
        let udf = SessionWindowStart::new();
        let ts = ColumnarValue::Array(Arc::new(TimestampMillisecondArray::from(vec![
            Some(1_000),
            None,
            Some(-5),
        ])));
        let result = udf
            .invoke_with_args(make_args(vec![ts, interval_dt(0, 60_000)], 3))
            .unwrap();
        assert_eq!(expect_ts_ms_array(result), vec![Some(1_000), None, Some(-5)]);
    }

    #[test]
    fn test_session_int64_array_becomes_timestamp() {
        let udf = SessionWindowStart::new();
        let ts = ColumnarValue::Array(Arc::new(Int64Array::from(vec![Some(7), None])));
        let result = udf
            .invoke_with_args(make_args(vec![ts, interval_dt(0, 60_000)], 2))
            .unwrap();
        assert_eq!(expect_ts_ms_array(result), vec![Some(7), None]);
    }

    // ── Cumulate tests ──────────────────────────────────────────────────

    #[test]
    fn test_cumulate_basic() {
        let udf = CumulateWindowStart::new();
        // step=1min, size=5min, ts=30s → epoch start = 0
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(30_000)),
                    interval_dt(0, 60_000),
                    interval_dt(0, 300_000),
                ],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(0));
    }

    #[test]
    fn test_cumulate_second_epoch() {
        let udf = CumulateWindowStart::new();
        // step=1min, size=5min, ts=350s → epoch start = 300_000
        let result = udf
            .invoke_with_args(make_args(
                vec![
                    ts_ms(Some(350_000)),
                    interval_dt(0, 60_000),
                    interval_dt(0, 300_000),
                ],
                1,
            ))
            .unwrap();
        assert_eq!(expect_ts_ms(result), Some(300_000));
    }

    #[test]
    fn test_cumulate_rejects_zero_step() {
        let udf = CumulateWindowStart::new();
        let result = udf.invoke_with_args(make_args(
            vec![ts_ms(Some(1000)), interval_dt(0, 0), interval_dt(0, 300_000)],
            1,
        ));
        assert!(result.is_err());
    }

    #[test]
    fn test_cumulate_rejects_step_exceeds_size() {
        let udf = CumulateWindowStart::new();
        let result = udf.invoke_with_args(make_args(
            vec![
                ts_ms(Some(1000)),
                interval_dt(0, 600_000),
                interval_dt(0, 300_000),
            ],
            1,
        ));
        assert!(result.is_err());
    }

    #[test]
    fn test_cumulate_rejects_not_divisible() {
        let udf = CumulateWindowStart::new();
        let result = udf.invoke_with_args(make_args(
            vec![
                ts_ms(Some(1000)),
                interval_dt(0, 70_000),
                interval_dt(0, 300_000),
            ],
            1,
        ));
        assert!(result.is_err());
    }

    #[test]
    fn test_cumulate_rejects_wrong_arg_count() {
        let udf = CumulateWindowStart::new();
        let result = udf.invoke_with_args(make_args(
            vec![ts_ms(Some(1000)), interval_dt(0, 60_000)],
            1,
        ));
        assert!(result.is_err());
    }

    // ── Registration & signature tests ───────────────────────────────────

    #[test]
    fn test_udf_registration() {
        let tumble = ScalarUDF::new_from_impl(TumbleWindowStart::new());
        assert_eq!(tumble.name(), "tumble");

        let hop = ScalarUDF::new_from_impl(HopWindowStart::new());
        assert_eq!(hop.name(), "hop");

        let session = ScalarUDF::new_from_impl(SessionWindowStart::new());
        assert_eq!(session.name(), "session");

        let cumulate = ScalarUDF::new_from_impl(CumulateWindowStart::new());
        assert_eq!(cumulate.name(), "cumulate");
    }

    #[test]
    fn test_udf_signatures_immutable() {
        assert_eq!(
            TumbleWindowStart::new().signature().volatility,
            Volatility::Immutable
        );
        assert_eq!(
            HopWindowStart::new().signature().volatility,
            Volatility::Immutable
        );
        assert_eq!(
            SessionWindowStart::new().signature().volatility,
            Volatility::Immutable
        );
        assert_eq!(
            CumulateWindowStart::new().signature().volatility,
            Volatility::Immutable
        );
    }

    #[test]
    fn test_tumble_return_type() {
        let udf = TumbleWindowStart::new();
        let rt = udf.return_type(&[]).unwrap();
        assert_eq!(rt, DataType::Timestamp(TimeUnit::Millisecond, None));
    }

    #[test]
    fn test_cumulate_return_type() {
        let udf = CumulateWindowStart::new();
        let rt = udf.return_type(&[]).unwrap();
        assert_eq!(rt, DataType::Timestamp(TimeUnit::Millisecond, None));
    }
}