Skip to main content

laminar_sql/datafusion/
window_udf.rs

1//! Window function UDFs for `DataFusion` integration
2//!
3//! Provides scalar UDFs that compute window start timestamps for
4//! streaming window operations:
5//!
//! - [`TumbleWindowStart`] — `tumble(timestamp, interval [, offset])` — fixed-size non-overlapping windows
//! - [`HopWindowStart`] — `hop(timestamp, slide, size [, offset])` — fixed-size overlapping windows
//! - [`SessionWindowStart`] — `session(timestamp, gap)` — pass-through for Ring 0 sessions
//! - [`CumulateWindowStart`] — `cumulate(timestamp, step, size)` — epoch start of cumulating windows
9//!
10//! These UDFs allow `DataFusion` to execute `GROUP BY TUMBLE(...)` style queries
11//! by computing the window start as a per-row scalar value.
12
13use std::any::Any;
14use std::hash::{Hash, Hasher};
15use std::sync::Arc;
16
17use arrow::datatypes::{DataType, Int64Type, TimeUnit, TimestampMillisecondType};
18use arrow_array::cast::AsArray;
19use arrow_array::{ArrayRef, TimestampMillisecondArray};
20use datafusion_common::{DataFusionError, Result, ScalarValue};
21use datafusion_expr::{
22    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
23};
24
25// ─── TumbleWindowStart ───────────────────────────────────────────────────────
26
27/// Computes the tumbling window start for a given timestamp.
28///
29/// `tumble(timestamp, interval)` returns `floor(ts / interval) * interval`,
30/// which is the start of the non-overlapping window that contains `ts`.
31///
32/// # Arguments
33///
34/// * Arg 0: Timestamp column or scalar (`TimestampMillisecond` or `Int64` ms)
35/// * Arg 1: Window size as an interval scalar
36///
37/// # Returns
38///
39/// `TimestampMillisecond` representing the window start.
40#[derive(Debug)]
41pub struct TumbleWindowStart {
42    signature: Signature,
43}
44
45impl TumbleWindowStart {
46    /// Creates a new tumble window start UDF.
47    #[must_use]
48    pub fn new() -> Self {
49        Self {
50            signature: Signature::new(
51                TypeSignature::OneOf(vec![TypeSignature::Any(2), TypeSignature::Any(3)]),
52                Volatility::Immutable,
53            ),
54        }
55    }
56}
57
58impl Default for TumbleWindowStart {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl PartialEq for TumbleWindowStart {
65    fn eq(&self, _other: &Self) -> bool {
66        true // All instances are identical
67    }
68}
69
70impl Eq for TumbleWindowStart {}
71
72impl Hash for TumbleWindowStart {
73    fn hash<H: Hasher>(&self, state: &mut H) {
74        "tumble".hash(state);
75    }
76}
77
78impl ScalarUDFImpl for TumbleWindowStart {
79    fn as_any(&self) -> &dyn Any {
80        self
81    }
82
83    fn name(&self) -> &'static str {
84        "tumble"
85    }
86
87    fn signature(&self) -> &Signature {
88        &self.signature
89    }
90
91    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
92        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
93    }
94
95    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
96        let ScalarFunctionArgs { args, .. } = args;
97        if args.len() < 2 || args.len() > 3 {
98            return Err(DataFusionError::Plan(
99                "tumble() requires 2-3 arguments: (timestamp, interval [, offset])".to_string(),
100            ));
101        }
102        let interval_ms = extract_interval_ms(&args[1])?;
103        if interval_ms <= 0 {
104            return Err(DataFusionError::Plan(
105                "tumble() interval must be positive".to_string(),
106            ));
107        }
108        let offset_ms = if args.len() == 3 {
109            extract_interval_ms(&args[2])?
110        } else {
111            0
112        };
113        compute_tumble_with_offset(&args[0], interval_ms, offset_ms)
114    }
115}
116
117// ─── HopWindowStart ──────────────────────────────────────────────────────────
118
119/// Computes the earliest hopping window start for a given timestamp.
120///
121/// `hop(timestamp, slide, size)` returns the start of the earliest window
122/// (of the given `size`, sliding by `slide`) that contains `ts`.
123///
124/// # Limitation
125///
126/// This returns only the *earliest* window start. Full multi-window
127/// assignment (one row per window) is handled by Ring 0 operators.
128///
129/// # Arguments
130///
131/// * Arg 0: Timestamp column or scalar
132/// * Arg 1: Slide interval scalar
133/// * Arg 2: Window size interval scalar
134#[derive(Debug)]
135pub struct HopWindowStart {
136    signature: Signature,
137}
138
139impl HopWindowStart {
140    /// Creates a new hop window start UDF.
141    #[must_use]
142    pub fn new() -> Self {
143        Self {
144            signature: Signature::new(
145                TypeSignature::OneOf(vec![TypeSignature::Any(3), TypeSignature::Any(4)]),
146                Volatility::Immutable,
147            ),
148        }
149    }
150}
151
152impl Default for HopWindowStart {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158impl PartialEq for HopWindowStart {
159    fn eq(&self, _other: &Self) -> bool {
160        true
161    }
162}
163
164impl Eq for HopWindowStart {}
165
166impl Hash for HopWindowStart {
167    fn hash<H: Hasher>(&self, state: &mut H) {
168        "hop".hash(state);
169    }
170}
171
172impl ScalarUDFImpl for HopWindowStart {
173    fn as_any(&self) -> &dyn Any {
174        self
175    }
176
177    fn name(&self) -> &'static str {
178        "hop"
179    }
180
181    fn signature(&self) -> &Signature {
182        &self.signature
183    }
184
185    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
186        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
187    }
188
189    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
190        let ScalarFunctionArgs { args, .. } = args;
191        if args.len() < 3 || args.len() > 4 {
192            return Err(DataFusionError::Plan(
193                "hop() requires 3-4 arguments: (timestamp, slide, size [, offset])".to_string(),
194            ));
195        }
196        let slide_ms = extract_interval_ms(&args[1])?;
197        let size_ms = extract_interval_ms(&args[2])?;
198        if slide_ms <= 0 || size_ms <= 0 {
199            return Err(DataFusionError::Plan(
200                "hop() slide and size must be positive".to_string(),
201            ));
202        }
203        let offset_ms = if args.len() == 4 {
204            extract_interval_ms(&args[3])?
205        } else {
206            0
207        };
208        compute_hop_with_offset(&args[0], slide_ms, size_ms, offset_ms)
209    }
210}
211
212// ─── SessionWindowStart ──────────────────────────────────────────────────────
213
214/// Pass-through UDF for session window compatibility.
215///
216/// `session(timestamp, gap)` returns the input timestamp unchanged.
217/// Session windows are data-dependent (gap-based grouping) and cannot
218/// be computed as a per-row scalar. The actual session assignment is
219/// handled by Ring 0 operators.
220///
221/// This UDF exists so that `GROUP BY SESSION(ts, gap)` is syntactically
222/// valid in `DataFusion` queries, with real session logic deferred to
223/// the streaming engine.
224#[derive(Debug)]
225pub struct SessionWindowStart {
226    signature: Signature,
227}
228
229impl SessionWindowStart {
230    /// Creates a new session window start UDF.
231    #[must_use]
232    pub fn new() -> Self {
233        Self {
234            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
235        }
236    }
237}
238
239impl Default for SessionWindowStart {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
245impl PartialEq for SessionWindowStart {
246    fn eq(&self, _other: &Self) -> bool {
247        true
248    }
249}
250
251impl Eq for SessionWindowStart {}
252
253impl Hash for SessionWindowStart {
254    fn hash<H: Hasher>(&self, state: &mut H) {
255        "session".hash(state);
256    }
257}
258
259impl ScalarUDFImpl for SessionWindowStart {
260    fn as_any(&self) -> &dyn Any {
261        self
262    }
263
264    fn name(&self) -> &'static str {
265        "session"
266    }
267
268    fn signature(&self) -> &Signature {
269        &self.signature
270    }
271
272    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
273        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
274    }
275
276    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
277        let ScalarFunctionArgs { args, .. } = args;
278        if args.len() != 2 {
279            return Err(DataFusionError::Plan(
280                "session() requires exactly 2 arguments: (timestamp, gap)".to_string(),
281            ));
282        }
283        // Pass-through: return the input timestamp as-is
284        match &args[0] {
285            ColumnarValue::Array(array) => {
286                let result = convert_to_timestamp_ms_array(array)?;
287                Ok(ColumnarValue::Array(result))
288            }
289            ColumnarValue::Scalar(scalar) => {
290                let ts_ms = scalar_to_timestamp_ms(scalar)?;
291                Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
292                    ts_ms, None,
293                )))
294            }
295        }
296    }
297}
298
299// ─── CumulateWindowStart ────────────────────────────────────────────────────
300
301/// Computes the cumulate window epoch start for a given timestamp.
302///
303/// `cumulate(timestamp, step, size)` returns `floor(ts / size) * size`,
304/// which is the epoch start for the cumulating window that contains `ts`.
305/// The actual multi-window assignment (one row per cumulating window) is
306/// handled by Ring 0 operators.
307///
308/// # Arguments
309///
310/// * Arg 0: Timestamp column or scalar (`TimestampMillisecond` or `Int64` ms)
311/// * Arg 1: Step interval scalar (window growth increment)
312/// * Arg 2: Max size interval scalar (epoch size)
313#[derive(Debug)]
314pub struct CumulateWindowStart {
315    signature: Signature,
316}
317
318impl CumulateWindowStart {
319    /// Creates a new cumulate window start UDF.
320    #[must_use]
321    pub fn new() -> Self {
322        Self {
323            signature: Signature::new(TypeSignature::Any(3), Volatility::Immutable),
324        }
325    }
326}
327
328impl Default for CumulateWindowStart {
329    fn default() -> Self {
330        Self::new()
331    }
332}
333
334impl PartialEq for CumulateWindowStart {
335    fn eq(&self, _other: &Self) -> bool {
336        true
337    }
338}
339
340impl Eq for CumulateWindowStart {}
341
342impl Hash for CumulateWindowStart {
343    fn hash<H: Hasher>(&self, state: &mut H) {
344        "cumulate".hash(state);
345    }
346}
347
348impl ScalarUDFImpl for CumulateWindowStart {
349    fn as_any(&self) -> &dyn Any {
350        self
351    }
352
353    fn name(&self) -> &'static str {
354        "cumulate"
355    }
356
357    fn signature(&self) -> &Signature {
358        &self.signature
359    }
360
361    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
362        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
363    }
364
365    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
366        let ScalarFunctionArgs { args, .. } = args;
367        if args.len() != 3 {
368            return Err(DataFusionError::Plan(
369                "cumulate() requires exactly 3 arguments: (timestamp, step, size)".to_string(),
370            ));
371        }
372        let step_ms = extract_interval_ms(&args[1])?;
373        let size_ms = extract_interval_ms(&args[2])?;
374        if step_ms <= 0 || size_ms <= 0 {
375            return Err(DataFusionError::Plan(
376                "cumulate() step and size must be positive".to_string(),
377            ));
378        }
379        if step_ms > size_ms {
380            return Err(DataFusionError::Plan(
381                "cumulate() step must not exceed size".to_string(),
382            ));
383        }
384        if size_ms % step_ms != 0 {
385            return Err(DataFusionError::Plan(
386                "cumulate() size must be evenly divisible by step".to_string(),
387            ));
388        }
389        // Return epoch start = floor(ts / size) * size (same as tumble with size)
390        compute_tumble(&args[0], size_ms)
391    }
392}
393
394// ─── Helper Functions ────────────────────────────────────────────────────────
395
396/// Extracts an interval value in milliseconds from a `ColumnarValue`.
397///
398/// Only scalar intervals are supported (array intervals would require
399/// per-row window sizes, which is not a valid streaming pattern).
400fn extract_interval_ms(value: &ColumnarValue) -> Result<i64> {
401    match value {
402        ColumnarValue::Scalar(scalar) => scalar_interval_to_ms(scalar),
403        ColumnarValue::Array(_) => Err(DataFusionError::NotImplemented(
404            "Array interval arguments not supported for window functions".to_string(),
405        )),
406    }
407}
408
409/// Converts a scalar interval to milliseconds.
410fn scalar_interval_to_ms(scalar: &ScalarValue) -> Result<i64> {
411    match scalar {
412        ScalarValue::IntervalDayTime(Some(v)) => {
413            Ok(i64::from(v.days) * 86_400_000 + i64::from(v.milliseconds))
414        }
415        ScalarValue::IntervalMonthDayNano(Some(v)) => {
416            if v.months != 0 {
417                return Err(DataFusionError::NotImplemented(
418                    "Month-based intervals not supported for window functions \
419                     (use days/hours/minutes/seconds)"
420                        .to_string(),
421                ));
422            }
423            Ok(i64::from(v.days) * 86_400_000 + v.nanoseconds / 1_000_000)
424        }
425        ScalarValue::IntervalYearMonth(_) => Err(DataFusionError::NotImplemented(
426            "Year-month intervals not supported for window functions".to_string(),
427        )),
428        ScalarValue::Int64(Some(ms)) => Ok(*ms),
429        _ => Err(DataFusionError::Plan(format!(
430            "Expected interval argument for window function, got: {scalar:?}"
431        ))),
432    }
433}
434
435/// Converts a scalar value to a timestamp in milliseconds.
436fn scalar_to_timestamp_ms(scalar: &ScalarValue) -> Result<Option<i64>> {
437    match scalar {
438        ScalarValue::TimestampMillisecond(v, _) | ScalarValue::Int64(v) => Ok(*v),
439        ScalarValue::TimestampMicrosecond(v, _) => Ok(v.map(|v| v / 1_000)),
440        ScalarValue::TimestampNanosecond(v, _) => Ok(v.map(|v| v / 1_000_000)),
441        ScalarValue::TimestampSecond(v, _) => Ok(v.map(|v| v * 1_000)),
442        _ => Err(DataFusionError::Plan(format!(
443            "Expected timestamp argument for window function, got: {scalar:?}"
444        ))),
445    }
446}
447
448/// Computes tumble window start for a `ColumnarValue`.
449fn compute_tumble(value: &ColumnarValue, interval_ms: i64) -> Result<ColumnarValue> {
450    match value {
451        ColumnarValue::Array(array) => {
452            let result = compute_tumble_array(array, interval_ms)?;
453            Ok(ColumnarValue::Array(result))
454        }
455        ColumnarValue::Scalar(scalar) => {
456            let ts_ms = scalar_to_timestamp_ms(scalar)?;
457            let window_start = ts_ms.map(|ts| ts - ts.rem_euclid(interval_ms));
458            Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
459                window_start,
460                None,
461            )))
462        }
463    }
464}
465
466/// Computes tumble window start for an array of timestamps.
467fn compute_tumble_array(array: &ArrayRef, interval_ms: i64) -> Result<ArrayRef> {
468    match array.data_type() {
469        DataType::Timestamp(TimeUnit::Millisecond, _) => {
470            let input = array.as_primitive::<TimestampMillisecondType>();
471            let result: TimestampMillisecondArray = input
472                .iter()
473                .map(|opt_ts| opt_ts.map(|ts| ts - ts.rem_euclid(interval_ms)))
474                .collect();
475            Ok(Arc::new(result))
476        }
477        DataType::Int64 => {
478            let input = array.as_primitive::<Int64Type>();
479            let result: TimestampMillisecondArray = input
480                .iter()
481                .map(|opt_ts| opt_ts.map(|ts| ts - ts.rem_euclid(interval_ms)))
482                .collect();
483            Ok(Arc::new(result))
484        }
485        other => Err(DataFusionError::Plan(format!(
486            "Unsupported timestamp type for tumble(): {other:?}. \
487             Use TimestampMillisecond or Int64."
488        ))),
489    }
490}
491
492/// Computes tumble window start with offset for a `ColumnarValue`.
493fn compute_tumble_with_offset(
494    value: &ColumnarValue,
495    interval_ms: i64,
496    offset_ms: i64,
497) -> Result<ColumnarValue> {
498    if offset_ms == 0 {
499        return compute_tumble(value, interval_ms);
500    }
501    match value {
502        ColumnarValue::Array(array) => {
503            let result = compute_tumble_array_with_offset(array, interval_ms, offset_ms)?;
504            Ok(ColumnarValue::Array(result))
505        }
506        ColumnarValue::Scalar(scalar) => {
507            let ts_ms = scalar_to_timestamp_ms(scalar)?;
508            let window_start = ts_ms.map(|ts| {
509                let adj = ts - offset_ms;
510                (adj - adj.rem_euclid(interval_ms)) + offset_ms
511            });
512            Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
513                window_start,
514                None,
515            )))
516        }
517    }
518}
519
520/// Computes tumble window start with offset for an array of timestamps.
521fn compute_tumble_array_with_offset(
522    array: &ArrayRef,
523    interval_ms: i64,
524    offset_ms: i64,
525) -> Result<ArrayRef> {
526    match array.data_type() {
527        DataType::Timestamp(TimeUnit::Millisecond, _) => {
528            let input = array.as_primitive::<TimestampMillisecondType>();
529            let result: TimestampMillisecondArray = input
530                .iter()
531                .map(|opt_ts| {
532                    opt_ts.map(|ts| {
533                        let adj = ts - offset_ms;
534                        (adj - adj.rem_euclid(interval_ms)) + offset_ms
535                    })
536                })
537                .collect();
538            Ok(Arc::new(result))
539        }
540        DataType::Int64 => {
541            let input = array.as_primitive::<Int64Type>();
542            let result: TimestampMillisecondArray = input
543                .iter()
544                .map(|opt_ts| {
545                    opt_ts.map(|ts| {
546                        let adj = ts - offset_ms;
547                        (adj - adj.rem_euclid(interval_ms)) + offset_ms
548                    })
549                })
550                .collect();
551            Ok(Arc::new(result))
552        }
553        other => Err(DataFusionError::Plan(format!(
554            "Unsupported timestamp type for tumble(): {other:?}. \
555             Use TimestampMillisecond or Int64."
556        ))),
557    }
558}
559
560/// Computes hop (earliest) window start for a `ColumnarValue`.
561fn compute_hop(value: &ColumnarValue, slide_ms: i64, size_ms: i64) -> Result<ColumnarValue> {
562    match value {
563        ColumnarValue::Array(array) => {
564            let result = compute_hop_array(array, slide_ms, size_ms)?;
565            Ok(ColumnarValue::Array(result))
566        }
567        ColumnarValue::Scalar(scalar) => {
568            let ts_ms = scalar_to_timestamp_ms(scalar)?;
569            let window_start = ts_ms.map(|ts| hop_earliest_start(ts, slide_ms, size_ms));
570            Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
571                window_start,
572                None,
573            )))
574        }
575    }
576}
577
578/// Computes hop window start for an array of timestamps.
579fn compute_hop_array(array: &ArrayRef, slide_ms: i64, size_ms: i64) -> Result<ArrayRef> {
580    match array.data_type() {
581        DataType::Timestamp(TimeUnit::Millisecond, _) => {
582            let input = array.as_primitive::<TimestampMillisecondType>();
583            let result: TimestampMillisecondArray = input
584                .iter()
585                .map(|opt_ts| opt_ts.map(|ts| hop_earliest_start(ts, slide_ms, size_ms)))
586                .collect();
587            Ok(Arc::new(result))
588        }
589        DataType::Int64 => {
590            let input = array.as_primitive::<Int64Type>();
591            let result: TimestampMillisecondArray = input
592                .iter()
593                .map(|opt_ts| opt_ts.map(|ts| hop_earliest_start(ts, slide_ms, size_ms)))
594                .collect();
595            Ok(Arc::new(result))
596        }
597        other => Err(DataFusionError::Plan(format!(
598            "Unsupported timestamp type for hop(): {other:?}. \
599             Use TimestampMillisecond or Int64."
600        ))),
601    }
602}
603
/// Start of the earliest hopping window that contains `ts`.
///
/// Windows are `[k * slide, k * slide + size)` for integer `k`. The
/// smallest `k` with `k * slide + size > ts` gives
/// `floor((ts - size + slide) / slide) * slide`. `div_euclid` is used so
/// the floor also holds for negative values.
#[inline]
fn hop_earliest_start(ts: i64, slide_ms: i64, size_ms: i64) -> i64 {
    let shifted = ts - size_ms + slide_ms;
    shifted.div_euclid(slide_ms) * slide_ms
}
613
614/// Computes hop (earliest) window start with offset.
615fn compute_hop_with_offset(
616    value: &ColumnarValue,
617    slide_ms: i64,
618    size_ms: i64,
619    offset_ms: i64,
620) -> Result<ColumnarValue> {
621    if offset_ms == 0 {
622        return compute_hop(value, slide_ms, size_ms);
623    }
624    match value {
625        ColumnarValue::Array(array) => {
626            let result = compute_hop_array_with_offset(array, slide_ms, size_ms, offset_ms)?;
627            Ok(ColumnarValue::Array(result))
628        }
629        ColumnarValue::Scalar(scalar) => {
630            let ts_ms = scalar_to_timestamp_ms(scalar)?;
631            let window_start =
632                ts_ms.map(|ts| hop_earliest_start(ts - offset_ms, slide_ms, size_ms) + offset_ms);
633            Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
634                window_start,
635                None,
636            )))
637        }
638    }
639}
640
641/// Computes hop window start with offset for an array.
642fn compute_hop_array_with_offset(
643    array: &ArrayRef,
644    slide_ms: i64,
645    size_ms: i64,
646    offset_ms: i64,
647) -> Result<ArrayRef> {
648    match array.data_type() {
649        DataType::Timestamp(TimeUnit::Millisecond, _) => {
650            let input = array.as_primitive::<TimestampMillisecondType>();
651            let result: TimestampMillisecondArray = input
652                .iter()
653                .map(|opt_ts| {
654                    opt_ts
655                        .map(|ts| hop_earliest_start(ts - offset_ms, slide_ms, size_ms) + offset_ms)
656                })
657                .collect();
658            Ok(Arc::new(result))
659        }
660        DataType::Int64 => {
661            let input = array.as_primitive::<Int64Type>();
662            let result: TimestampMillisecondArray = input
663                .iter()
664                .map(|opt_ts| {
665                    opt_ts
666                        .map(|ts| hop_earliest_start(ts - offset_ms, slide_ms, size_ms) + offset_ms)
667                })
668                .collect();
669            Ok(Arc::new(result))
670        }
671        other => Err(DataFusionError::Plan(format!(
672            "Unsupported timestamp type for hop(): {other:?}. \
673             Use TimestampMillisecond or Int64."
674        ))),
675    }
676}
677
678/// Converts a timestamp array to `TimestampMillisecond` for consistent output.
679fn convert_to_timestamp_ms_array(array: &ArrayRef) -> Result<ArrayRef> {
680    match array.data_type() {
681        DataType::Timestamp(TimeUnit::Millisecond, _) => Ok(Arc::clone(array)),
682        DataType::Int64 => {
683            let input = array.as_primitive::<Int64Type>();
684            let result: TimestampMillisecondArray = input.iter().collect();
685            Ok(Arc::new(result))
686        }
687        other => Err(DataFusionError::Plan(format!(
688            "Unsupported timestamp type for session(): {other:?}. \
689             Use TimestampMillisecond or Int64."
690        ))),
691    }
692}
693
694#[cfg(test)]
695mod tests {
696    use super::*;
697    use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};
698    use arrow_array::Array;
699    use arrow_schema::Field;
700    use datafusion_common::config::ConfigOptions;
701    use datafusion_expr::ScalarUDF;
702
703    fn interval_dt(days: i32, ms: i32) -> ColumnarValue {
704        ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new(
705            days, ms,
706        ))))
707    }
708
709    fn ts_ms(ms: Option<i64>) -> ColumnarValue {
710        ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(ms, None))
711    }
712
713    fn expect_ts_ms(result: ColumnarValue) -> Option<i64> {
714        match result {
715            ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, _)) => v,
716            other => panic!("Expected TimestampMillisecond scalar, got: {other:?}"),
717        }
718    }
719
720    fn make_args(args: Vec<ColumnarValue>, rows: usize) -> ScalarFunctionArgs {
721        ScalarFunctionArgs {
722            args,
723            arg_fields: vec![],
724            number_rows: rows,
725            return_field: Arc::new(Field::new(
726                "output",
727                DataType::Timestamp(TimeUnit::Millisecond, None),
728                true,
729            )),
730            config_options: Arc::new(ConfigOptions::default()),
731        }
732    }
733
734    // ── Tumble tests ─────────────────────────────────────────────────────
735
736    #[test]
737    fn test_tumble_basic() {
738        let udf = TumbleWindowStart::new();
739        // 5-minute interval = 300_000 ms, timestamp at 7 min
740        let result = udf
741            .invoke_with_args(make_args(
742                vec![ts_ms(Some(420_000)), interval_dt(0, 300_000)],
743                1,
744            ))
745            .unwrap();
746        assert_eq!(expect_ts_ms(result), Some(300_000));
747    }
748
749    #[test]
750    fn test_tumble_exact_boundary() {
751        let udf = TumbleWindowStart::new();
752        let result = udf
753            .invoke_with_args(make_args(
754                vec![ts_ms(Some(300_000)), interval_dt(0, 300_000)],
755                1,
756            ))
757            .unwrap();
758        assert_eq!(expect_ts_ms(result), Some(300_000));
759    }
760
761    #[test]
762    fn test_tumble_zero_timestamp() {
763        let udf = TumbleWindowStart::new();
764        let result = udf
765            .invoke_with_args(make_args(vec![ts_ms(Some(0)), interval_dt(0, 300_000)], 1))
766            .unwrap();
767        assert_eq!(expect_ts_ms(result), Some(0));
768    }
769
770    #[test]
771    fn test_tumble_null_handling() {
772        let udf = TumbleWindowStart::new();
773        let result = udf
774            .invoke_with_args(make_args(vec![ts_ms(None), interval_dt(0, 300_000)], 1))
775            .unwrap();
776        assert_eq!(expect_ts_ms(result), None);
777    }
778
779    #[test]
780    fn test_tumble_array_input() {
781        let udf = TumbleWindowStart::new();
782        let ts_array = TimestampMillisecondArray::from(vec![
783            Some(0),
784            Some(150_000),
785            Some(300_000),
786            Some(420_000),
787            None,
788        ]);
789        let ts = ColumnarValue::Array(Arc::new(ts_array));
790        let interval = interval_dt(0, 300_000);
791
792        let result = udf
793            .invoke_with_args(make_args(vec![ts, interval], 5))
794            .unwrap();
795        match result {
796            ColumnarValue::Array(arr) => {
797                let r = arr.as_primitive::<TimestampMillisecondType>();
798                assert_eq!(r.value(0), 0);
799                assert_eq!(r.value(1), 0);
800                assert_eq!(r.value(2), 300_000);
801                assert_eq!(r.value(3), 300_000);
802                assert!(r.is_null(4));
803            }
804            ColumnarValue::Scalar(_) => panic!("Expected array result"),
805        }
806    }
807
808    #[test]
809    fn test_tumble_month_day_nano_interval() {
810        let udf = TumbleWindowStart::new();
811        // 1 hour = 3_600_000_000_000 nanoseconds
812        let interval = ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(
813            IntervalMonthDayNano::new(0, 0, 3_600_000_000_000),
814        )));
815        // 90 minutes = 5_400_000 ms
816        let result = udf
817            .invoke_with_args(make_args(vec![ts_ms(Some(5_400_000)), interval], 1))
818            .unwrap();
819        assert_eq!(expect_ts_ms(result), Some(3_600_000));
820    }
821
822    #[test]
823    fn test_tumble_rejects_zero_interval() {
824        let udf = TumbleWindowStart::new();
825        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000)), interval_dt(0, 0)], 1));
826        assert!(result.is_err());
827    }
828
829    #[test]
830    fn test_tumble_rejects_wrong_arg_count() {
831        let udf = TumbleWindowStart::new();
832        let result = udf.invoke_with_args(make_args(vec![ts_ms(Some(1000))], 1));
833        assert!(result.is_err());
834    }
835
836    // ── Hop tests ────────────────────────────────────────────────────────
837
838    #[test]
839    fn test_hop_basic() {
840        let udf = HopWindowStart::new();
841        // slide=5min, size=10min, ts=7min
842        let result = udf
843            .invoke_with_args(make_args(
844                vec![
845                    ts_ms(Some(420_000)),
846                    interval_dt(0, 300_000),
847                    interval_dt(0, 600_000),
848                ],
849                1,
850            ))
851            .unwrap();
852        // Earliest 10-min window (sliding 5min) containing 420_000:
853        // adjusted = 420_000 - 600_000 + 300_000 = 120_000
854        // 120_000 - (120_000 % 300_000) = 120_000 - 120_000 = 0
855        assert_eq!(expect_ts_ms(result), Some(0));
856    }
857
858    #[test]
859    fn test_hop_at_boundary() {
860        let udf = HopWindowStart::new();
861        // slide=5min, size=10min, ts=exactly 5min
862        let result = udf
863            .invoke_with_args(make_args(
864                vec![
865                    ts_ms(Some(300_000)),
866                    interval_dt(0, 300_000),
867                    interval_dt(0, 600_000),
868                ],
869                1,
870            ))
871            .unwrap();
872        // adjusted = 300_000 - 600_000 + 300_000 = 0
873        // 0 - (0 % 300_000) = 0
874        assert_eq!(expect_ts_ms(result), Some(0));
875    }
876
877    #[test]
878    fn test_hop_rejects_wrong_arg_count() {
879        let udf = HopWindowStart::new();
880        let result = udf.invoke_with_args(make_args(
881            vec![ts_ms(Some(1000)), interval_dt(0, 300_000)],
882            1,
883        ));
884        assert!(result.is_err());
885    }
886
887    // ── Session tests ────────────────────────────────────────────────────
888
889    #[test]
890    fn test_session_passthrough_scalar() {
891        let udf = SessionWindowStart::new();
892        let result = udf
893            .invoke_with_args(make_args(
894                vec![ts_ms(Some(42_000)), interval_dt(0, 60_000)],
895                1,
896            ))
897            .unwrap();
898        assert_eq!(expect_ts_ms(result), Some(42_000));
899    }
900
901    #[test]
902    fn test_session_passthrough_null() {
903        let udf = SessionWindowStart::new();
904        let result = udf
905            .invoke_with_args(make_args(vec![ts_ms(None), interval_dt(0, 60_000)], 1))
906            .unwrap();
907        assert_eq!(expect_ts_ms(result), None);
908    }
909
    // ── Cumulate tests ──────────────────────────────────────────────────
913
914    #[test]
915    fn test_cumulate_basic() {
916        let udf = CumulateWindowStart::new();
917        // step=1min, size=5min, ts=30s → epoch start = 0
918        let result = udf
919            .invoke_with_args(make_args(
920                vec![
921                    ts_ms(Some(30_000)),
922                    interval_dt(0, 60_000),
923                    interval_dt(0, 300_000),
924                ],
925                1,
926            ))
927            .unwrap();
928        assert_eq!(expect_ts_ms(result), Some(0));
929    }
930
931    #[test]
932    fn test_cumulate_second_epoch() {
933        let udf = CumulateWindowStart::new();
934        // step=1min, size=5min, ts=350s → epoch start = 300_000
935        let result = udf
936            .invoke_with_args(make_args(
937                vec![
938                    ts_ms(Some(350_000)),
939                    interval_dt(0, 60_000),
940                    interval_dt(0, 300_000),
941                ],
942                1,
943            ))
944            .unwrap();
945        assert_eq!(expect_ts_ms(result), Some(300_000));
946    }
947
948    #[test]
949    fn test_cumulate_rejects_step_exceeds_size() {
950        let udf = CumulateWindowStart::new();
951        let result = udf.invoke_with_args(make_args(
952            vec![
953                ts_ms(Some(1000)),
954                interval_dt(0, 600_000),
955                interval_dt(0, 300_000),
956            ],
957            1,
958        ));
959        assert!(result.is_err());
960    }
961
962    #[test]
963    fn test_cumulate_rejects_not_divisible() {
964        let udf = CumulateWindowStart::new();
965        let result = udf.invoke_with_args(make_args(
966            vec![
967                ts_ms(Some(1000)),
968                interval_dt(0, 70_000),
969                interval_dt(0, 300_000),
970            ],
971            1,
972        ));
973        assert!(result.is_err());
974    }
975
976    #[test]
977    fn test_cumulate_rejects_wrong_arg_count() {
978        let udf = CumulateWindowStart::new();
979        let result = udf.invoke_with_args(make_args(
980            vec![ts_ms(Some(1000)), interval_dt(0, 60_000)],
981            1,
982        ));
983        assert!(result.is_err());
984    }
985
986    // ── Registration & signature tests ───────────────────────────────────
987
988    #[test]
989    fn test_udf_registration() {
990        let tumble = ScalarUDF::new_from_impl(TumbleWindowStart::new());
991        assert_eq!(tumble.name(), "tumble");
992
993        let hop = ScalarUDF::new_from_impl(HopWindowStart::new());
994        assert_eq!(hop.name(), "hop");
995
996        let session = ScalarUDF::new_from_impl(SessionWindowStart::new());
997        assert_eq!(session.name(), "session");
998
999        let cumulate = ScalarUDF::new_from_impl(CumulateWindowStart::new());
1000        assert_eq!(cumulate.name(), "cumulate");
1001    }
1002
1003    #[test]
1004    fn test_udf_signatures_immutable() {
1005        assert_eq!(
1006            TumbleWindowStart::new().signature().volatility,
1007            Volatility::Immutable
1008        );
1009        assert_eq!(
1010            HopWindowStart::new().signature().volatility,
1011            Volatility::Immutable
1012        );
1013        assert_eq!(
1014            SessionWindowStart::new().signature().volatility,
1015            Volatility::Immutable
1016        );
1017        assert_eq!(
1018            CumulateWindowStart::new().signature().volatility,
1019            Volatility::Immutable
1020        );
1021    }
1022
1023    #[test]
1024    fn test_tumble_return_type() {
1025        let udf = TumbleWindowStart::new();
1026        let rt = udf.return_type(&[]).unwrap();
1027        assert_eq!(rt, DataType::Timestamp(TimeUnit::Millisecond, None));
1028    }
1029
1030    #[test]
1031    fn test_cumulate_return_type() {
1032        let udf = CumulateWindowStart::new();
1033        let rt = udf.return_type(&[]).unwrap();
1034        assert_eq!(rt, DataType::Timestamp(TimeUnit::Millisecond, None));
1035    }
1036}