datafusion_python/
functions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19
20use datafusion::functions_aggregate::all_default_aggregate_functions;
21use datafusion::functions_window::all_default_window_functions;
22use datafusion::logical_expr::expr::WindowFunctionParams;
23use datafusion::logical_expr::ExprFunctionExt;
24use datafusion::logical_expr::WindowFrame;
25use pyo3::{prelude::*, wrap_pyfunction};
26
27use crate::common::data_type::NullTreatment;
28use crate::common::data_type::PyScalarValue;
29use crate::context::PySessionContext;
30use crate::errors::PyDataFusionError;
31use crate::errors::PyDataFusionResult;
32use crate::expr::conditional_expr::PyCaseBuilder;
33use crate::expr::sort_expr::to_sort_expressions;
34use crate::expr::sort_expr::PySortExpr;
35use crate::expr::window::PyWindowFrame;
36use crate::expr::PyExpr;
37use datafusion::common::{Column, ScalarValue, TableReference};
38use datafusion::execution::FunctionRegistry;
39use datafusion::functions;
40use datafusion::functions_aggregate;
41use datafusion::functions_window;
42use datafusion::logical_expr::expr::Alias;
43use datafusion::logical_expr::sqlparser::ast::NullTreatment as DFNullTreatment;
44use datafusion::logical_expr::{expr::WindowFunction, lit, Expr, WindowFunctionDefinition};
45
46fn add_builder_fns_to_aggregate(
47    agg_fn: Expr,
48    distinct: Option<bool>,
49    filter: Option<PyExpr>,
50    order_by: Option<Vec<PySortExpr>>,
51    null_treatment: Option<NullTreatment>,
52) -> PyDataFusionResult<PyExpr> {
53    // Since ExprFuncBuilder::new() is private, we can guarantee initializing
54    // a builder with an `null_treatment` with option None
55    let mut builder = agg_fn.null_treatment(None);
56
57    if let Some(order_by_cols) = order_by {
58        let order_by_cols = to_sort_expressions(order_by_cols);
59        builder = builder.order_by(order_by_cols);
60    }
61
62    if let Some(true) = distinct {
63        builder = builder.distinct();
64    }
65
66    if let Some(filter) = filter {
67        builder = builder.filter(filter.expr);
68    }
69
70    builder = builder.null_treatment(null_treatment.map(DFNullTreatment::from));
71
72    Ok(builder.build()?.into())
73}
74
75#[pyfunction]
76fn in_list(expr: PyExpr, value: Vec<PyExpr>, negated: bool) -> PyExpr {
77    datafusion::logical_expr::in_list(
78        expr.expr,
79        value.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
80        negated,
81    )
82    .into()
83}
84
85#[pyfunction]
86fn make_array(exprs: Vec<PyExpr>) -> PyExpr {
87    datafusion::functions_nested::expr_fn::make_array(exprs.into_iter().map(|x| x.into()).collect())
88        .into()
89}
90
91#[pyfunction]
92fn array_concat(exprs: Vec<PyExpr>) -> PyExpr {
93    let exprs = exprs.into_iter().map(|x| x.into()).collect();
94    datafusion::functions_nested::expr_fn::array_concat(exprs).into()
95}
96
97#[pyfunction]
98fn array_cat(exprs: Vec<PyExpr>) -> PyExpr {
99    array_concat(exprs)
100}
101
102#[pyfunction]
103#[pyo3(signature = (array, element, index=None))]
104fn array_position(array: PyExpr, element: PyExpr, index: Option<i64>) -> PyExpr {
105    let index = ScalarValue::Int64(index);
106    let index = Expr::Literal(index);
107    datafusion::functions_nested::expr_fn::array_position(array.into(), element.into(), index)
108        .into()
109}
110
111#[pyfunction]
112#[pyo3(signature = (array, begin, end, stride=None))]
113fn array_slice(array: PyExpr, begin: PyExpr, end: PyExpr, stride: Option<PyExpr>) -> PyExpr {
114    datafusion::functions_nested::expr_fn::array_slice(
115        array.into(),
116        begin.into(),
117        end.into(),
118        stride.map(Into::into),
119    )
120    .into()
121}
122
123/// Computes a binary hash of the given data. type is the algorithm to use.
124/// Standard algorithms are md5, sha224, sha256, sha384, sha512, blake2s, blake2b, and blake3.
125// #[pyfunction(value, method)]
126#[pyfunction]
127fn digest(value: PyExpr, method: PyExpr) -> PyExpr {
128    PyExpr {
129        expr: functions::expr_fn::digest(value.expr, method.expr),
130    }
131}
132
133/// Concatenates the text representations of all the arguments.
134/// NULL arguments are ignored.
135#[pyfunction]
136fn concat(args: Vec<PyExpr>) -> PyResult<PyExpr> {
137    let args = args.into_iter().map(|e| e.expr).collect::<Vec<_>>();
138    Ok(functions::string::expr_fn::concat(args).into())
139}
140
141/// Concatenates all but the first argument, with separators.
142/// The first argument is used as the separator string, and should not be NULL.
143/// Other NULL arguments are ignored.
144#[pyfunction]
145fn concat_ws(sep: String, args: Vec<PyExpr>) -> PyResult<PyExpr> {
146    let args = args.into_iter().map(|e| e.expr).collect::<Vec<_>>();
147    Ok(functions::string::expr_fn::concat_ws(lit(sep), args).into())
148}
149
150#[pyfunction]
151#[pyo3(signature = (values, regex, flags=None))]
152fn regexp_like(values: PyExpr, regex: PyExpr, flags: Option<PyExpr>) -> PyResult<PyExpr> {
153    Ok(functions::expr_fn::regexp_like(values.expr, regex.expr, flags.map(|x| x.expr)).into())
154}
155
156#[pyfunction]
157#[pyo3(signature = (values, regex, flags=None))]
158fn regexp_match(values: PyExpr, regex: PyExpr, flags: Option<PyExpr>) -> PyResult<PyExpr> {
159    Ok(functions::expr_fn::regexp_match(values.expr, regex.expr, flags.map(|x| x.expr)).into())
160}
161
162#[pyfunction]
163#[pyo3(signature = (string, pattern, replacement, flags=None))]
164/// Replaces substring(s) matching a POSIX regular expression.
165fn regexp_replace(
166    string: PyExpr,
167    pattern: PyExpr,
168    replacement: PyExpr,
169    flags: Option<PyExpr>,
170) -> PyResult<PyExpr> {
171    Ok(functions::expr_fn::regexp_replace(
172        string.into(),
173        pattern.into(),
174        replacement.into(),
175        flags.map(|x| x.expr),
176    )
177    .into())
178}
179
180#[pyfunction]
181#[pyo3(signature = (string, pattern, start, flags=None))]
182/// Returns the number of matches found in the string.
183fn regexp_count(
184    string: PyExpr,
185    pattern: PyExpr,
186    start: Option<PyExpr>,
187    flags: Option<PyExpr>,
188) -> PyResult<PyExpr> {
189    Ok(functions::expr_fn::regexp_count(
190        string.expr,
191        pattern.expr,
192        start.map(|x| x.expr),
193        flags.map(|x| x.expr),
194    )
195    .into())
196}
197
198/// Creates a new Sort Expr
199#[pyfunction]
200fn order_by(expr: PyExpr, asc: bool, nulls_first: bool) -> PyResult<PySortExpr> {
201    Ok(PySortExpr::from(datafusion::logical_expr::expr::Sort {
202        expr: expr.expr,
203        asc,
204        nulls_first,
205    }))
206}
207
208/// Creates a new Alias Expr
209#[pyfunction]
210#[pyo3(signature = (expr, name, metadata=None))]
211fn alias(expr: PyExpr, name: &str, metadata: Option<HashMap<String, String>>) -> PyResult<PyExpr> {
212    let relation: Option<TableReference> = None;
213    Ok(PyExpr {
214        expr: datafusion::logical_expr::Expr::Alias(
215            Alias::new(expr.expr, relation, name).with_metadata(metadata),
216        ),
217    })
218}
219
220/// Create a column reference Expr
221#[pyfunction]
222fn col(name: &str) -> PyResult<PyExpr> {
223    Ok(PyExpr {
224        expr: datafusion::logical_expr::Expr::Column(Column::new_unqualified(name)),
225    })
226}
227
228/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
229#[pyfunction]
230fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
231    Ok(PyCaseBuilder {
232        case_builder: datafusion::logical_expr::case(expr.expr),
233    })
234}
235
236/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
237#[pyfunction]
238fn when(when: PyExpr, then: PyExpr) -> PyResult<PyCaseBuilder> {
239    Ok(PyCaseBuilder {
240        case_builder: datafusion::logical_expr::when(when.expr, then.expr),
241    })
242}
243
244/// Helper function to find the appropriate window function.
245///
246/// Search procedure:
247/// 1) Search built in window functions, which are being deprecated.
248/// 1) If a session context is provided:
249///      1) search User Defined Aggregate Functions (UDAFs)
250///      1) search registered window functions
251///      1) search registered aggregate functions
252/// 1) If no function has been found, search default aggregate functions.
253///
254/// NOTE: we search the built-ins first because the `UDAF` versions currently do not have the same behavior.
255fn find_window_fn(
256    name: &str,
257    ctx: Option<PySessionContext>,
258) -> PyDataFusionResult<WindowFunctionDefinition> {
259    if let Some(ctx) = ctx {
260        // search UDAFs
261        let udaf = ctx
262            .ctx
263            .udaf(name)
264            .map(WindowFunctionDefinition::AggregateUDF)
265            .ok();
266
267        if let Some(udaf) = udaf {
268            return Ok(udaf);
269        }
270
271        let session_state = ctx.ctx.state();
272
273        // search registered window functions
274        let window_fn = session_state
275            .window_functions()
276            .get(name)
277            .map(|f| WindowFunctionDefinition::WindowUDF(f.clone()));
278
279        if let Some(window_fn) = window_fn {
280            return Ok(window_fn);
281        }
282
283        // search registered aggregate functions
284        let agg_fn = session_state
285            .aggregate_functions()
286            .get(name)
287            .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone()));
288
289        if let Some(agg_fn) = agg_fn {
290            return Ok(agg_fn);
291        }
292    }
293
294    // search default aggregate functions
295    let agg_fn = all_default_aggregate_functions()
296        .iter()
297        .find(|v| v.name() == name || v.aliases().contains(&name.to_string()))
298        .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone()));
299
300    if let Some(agg_fn) = agg_fn {
301        return Ok(agg_fn);
302    }
303
304    // search default window functions
305    let window_fn = all_default_window_functions()
306        .iter()
307        .find(|v| v.name() == name || v.aliases().contains(&name.to_string()))
308        .map(|f| WindowFunctionDefinition::WindowUDF(f.clone()));
309
310    if let Some(window_fn) = window_fn {
311        return Ok(window_fn);
312    }
313
314    Err(PyDataFusionError::Common(format!(
315        "window function `{name}` not found"
316    )))
317}
318
319/// Creates a new Window function expression
320#[pyfunction]
321#[pyo3(signature = (name, args, partition_by=None, order_by=None, window_frame=None, ctx=None))]
322fn window(
323    name: &str,
324    args: Vec<PyExpr>,
325    partition_by: Option<Vec<PyExpr>>,
326    order_by: Option<Vec<PySortExpr>>,
327    window_frame: Option<PyWindowFrame>,
328    ctx: Option<PySessionContext>,
329) -> PyResult<PyExpr> {
330    let fun = find_window_fn(name, ctx)?;
331
332    let window_frame = window_frame
333        .map(|w| w.into())
334        .unwrap_or(WindowFrame::new(order_by.as_ref().map(|v| !v.is_empty())));
335
336    Ok(PyExpr {
337        expr: datafusion::logical_expr::Expr::WindowFunction(WindowFunction {
338            fun,
339            params: WindowFunctionParams {
340                args: args.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
341                partition_by: partition_by
342                    .unwrap_or_default()
343                    .into_iter()
344                    .map(|x| x.expr)
345                    .collect::<Vec<_>>(),
346                order_by: order_by
347                    .unwrap_or_default()
348                    .into_iter()
349                    .map(|x| x.into())
350                    .collect::<Vec<_>>(),
351                window_frame,
352                null_treatment: None,
353            },
354        }),
355    })
356}
357
/// Generates a pyo3 wrapper around an aggregate in `functions_aggregate::expr_fn`.
///
/// With only a name, the wrapper takes one `PyExpr` argument named `expr`;
/// otherwise the listed identifiers become its positional `PyExpr` arguments.
/// Every wrapper also exposes the optional `distinct`, `filter`, `order_by` and
/// `null_treatment` arguments, applied by `add_builder_fns_to_aggregate`. The
/// Python-side wrappers are relied on to pass only the options that are
/// appropriate for each aggregate.
macro_rules! aggregate_function {
    ($func: ident) => {
        aggregate_function!($func, expr);
    };
    ($func: ident, $($param:ident)*) => {
        #[pyfunction]
        #[pyo3(signature = ($($param),*, distinct=None, filter=None, order_by=None, null_treatment=None))]
        fn $func(
            $($param: PyExpr),*,
            distinct: Option<bool>,
            filter: Option<PyExpr>,
            order_by: Option<Vec<PySortExpr>>,
            null_treatment: Option<NullTreatment>
        ) -> PyDataFusionResult<PyExpr> {
            add_builder_fns_to_aggregate(
                functions_aggregate::expr_fn::$func($($param.into()),*),
                distinct,
                filter,
                order_by,
                null_treatment,
            )
        }
    };
}
382
/// Generates a `#[pyfunction]` wrapper around a function in
/// `datafusion::functions::expr_fn` that takes explicitly named arguments.
///
/// Accepted forms (the arm order matters: a list of identifiers must be tried
/// before the arm that would read a lone identifier as a doc expression):
/// - `expr_fn!(name)` - no arguments; the docstring is the function name.
/// - `expr_fn!(name, a b ...)` - `PyExpr` arguments `a`, `b`, ...; docstring is the name.
/// - `expr_fn!(name, "doc")` - no arguments with an explicit docstring.
/// - `expr_fn!(name, a b ..., "doc")` - named arguments with an explicit docstring.
macro_rules! expr_fn {
    ($func: ident) => {
        expr_fn!($func, , stringify!($func));
    };
    ($func:ident, $($param:ident)*) => {
        expr_fn!($func, $($param)*, stringify!($func));
    };
    ($func: ident, $doc: expr) => {
        expr_fn!($func, , $doc);
    };
    ($func: ident, $($param:ident)*, $doc: expr) => {
        #[doc = $doc]
        #[pyfunction]
        fn $func($($param: PyExpr),*) -> PyExpr {
            let call = functions::expr_fn::$func($($param.into()),*);
            call.into()
        }
    };
}
/// Generates a `#[pyfunction]` wrapper around a variadic function in
/// `datafusion::functions::expr_fn`.
///
/// The wrapper accepts any number of positional `PyExpr` arguments
/// (`pyo3(signature = (*args))`) and forwards them as one `Vec`. Without an
/// explicit docstring, the function name is used.
macro_rules! expr_fn_vec {
    ($func: ident) => {
        expr_fn_vec!($func, stringify!($func));
    };
    ($func: ident, $doc: expr) => {
        #[doc = $doc]
        #[pyfunction]
        #[pyo3(signature = (*args))]
        fn $func(args: Vec<PyExpr>) -> PyExpr {
            let exprs: Vec<_> = args.into_iter().map(Into::into).collect();
            functions::expr_fn::$func(exprs).into()
        }
    };
}
421
/// Generates a `#[pyfunction]` wrapper around a function in
/// `datafusion::functions_nested::expr_fn` that takes explicitly named arguments.
///
/// The accepted forms mirror `expr_fn!`: an optional space-separated list of
/// argument names, then an optional docstring (defaulting to the function name).
macro_rules! array_fn {
    ($func: ident) => {
        array_fn!($func, , stringify!($func));
    };
    ($func:ident, $($param:ident)*) => {
        array_fn!($func, $($param)*, stringify!($func));
    };
    ($func: ident, $doc: expr) => {
        array_fn!($func, , $doc);
    };
    ($func: ident, $($param:ident)*, $doc:expr) => {
        #[doc = $doc]
        #[pyfunction]
        fn $func($($param: PyExpr),*) -> PyExpr {
            let call = datafusion::functions_nested::expr_fn::$func($($param.into()),*);
            call.into()
        }
    };
}
443
444expr_fn!(abs, num);
445expr_fn!(acos, num);
446expr_fn!(acosh, num);
447expr_fn!(ascii, arg1, "Returns the numeric code of the first character of the argument. In UTF8 encoding, returns the Unicode code point of the character. In other multibyte encodings, the argument must be an ASCII character.");
448expr_fn!(asin, num);
449expr_fn!(asinh, num);
450expr_fn!(atan, num);
451expr_fn!(atanh, num);
452expr_fn!(atan2, y x);
453expr_fn!(
454    bit_length,
455    arg,
456    "Returns number of bits in the string (8 times the octet_length)."
457);
458expr_fn_vec!(btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string.");
459expr_fn!(cbrt, num);
460expr_fn!(ceil, num);
461expr_fn!(
462    character_length,
463    string,
464    "Returns number of characters in the string."
465);
466expr_fn!(length, string);
467expr_fn!(char_length, string);
468expr_fn!(chr, arg, "Returns the character with the given code.");
469expr_fn_vec!(coalesce);
470expr_fn!(cos, num);
471expr_fn!(cosh, num);
472expr_fn!(cot, num);
473expr_fn!(degrees, num);
474expr_fn!(decode, input encoding);
475expr_fn!(encode, input encoding);
476expr_fn!(ends_with, string suffix, "Returns true if string ends with suffix.");
477expr_fn!(exp, num);
478expr_fn!(factorial, num);
479expr_fn!(floor, num);
480expr_fn!(gcd, x y);
481expr_fn!(initcap, string, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.");
482expr_fn!(isnan, num);
483expr_fn!(iszero, num);
484expr_fn!(levenshtein, string1 string2);
485expr_fn!(lcm, x y);
486expr_fn!(left, string n, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters.");
487expr_fn!(ln, num);
488expr_fn!(log, base num);
489expr_fn!(log10, num);
490expr_fn!(log2, num);
491expr_fn!(lower, arg1, "Converts the string to all lower case");
492expr_fn_vec!(lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right).");
493expr_fn_vec!(ltrim, "Removes the longest string containing only characters in characters (a space by default) from the start of string.");
494expr_fn!(
495    md5,
496    input_arg,
497    "Computes the MD5 hash of the argument, with the result written in hexadecimal."
498);
499expr_fn!(
500    nanvl,
501    x y,
502    "Returns x if x is not NaN otherwise returns y."
503);
504expr_fn!(
505    nvl,
506    x y,
507    "Returns x if x is not NULL otherwise returns y."
508);
509expr_fn!(nullif, arg_1 arg_2);
510expr_fn!(octet_length, args, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces.");
511expr_fn_vec!(overlay);
512expr_fn!(pi);
513expr_fn!(power, base exponent);
514expr_fn!(radians, num);
515expr_fn!(repeat, string n, "Repeats string the specified number of times.");
516expr_fn!(
517    replace,
518    string from to,
519    "Replaces all occurrences in string of substring from with substring to."
520);
521expr_fn!(
522    reverse,
523    string,
524    "Reverses the order of the characters in the string."
525);
526expr_fn!(right, string n, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters.");
527expr_fn_vec!(round);
528expr_fn_vec!(rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.");
529expr_fn_vec!(rtrim, "Removes the longest string containing only characters in characters (a space by default) from the end of string.");
530expr_fn!(sha224, input_arg1);
531expr_fn!(sha256, input_arg1);
532expr_fn!(sha384, input_arg1);
533expr_fn!(sha512, input_arg1);
534expr_fn!(signum, num);
535expr_fn!(sin, num);
536expr_fn!(sinh, num);
537expr_fn!(
538    split_part,
539    string delimiter index,
540    "Splits string at occurrences of delimiter and returns the n'th field (counting from one)."
541);
542expr_fn!(sqrt, num);
543expr_fn!(starts_with, string prefix, "Returns true if string starts with prefix.");
544expr_fn!(strpos, string substring, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)");
545expr_fn!(substr, string position);
546expr_fn!(substr_index, string delimiter count);
547expr_fn!(substring, string position length);
548expr_fn!(find_in_set, string string_list);
549expr_fn!(tan, num);
550expr_fn!(tanh, num);
551expr_fn!(
552    to_hex,
553    arg1,
554    "Converts the number to its equivalent hexadecimal representation."
555);
556expr_fn!(now);
557expr_fn_vec!(to_timestamp);
558expr_fn_vec!(to_timestamp_millis);
559expr_fn_vec!(to_timestamp_nanos);
560expr_fn_vec!(to_timestamp_micros);
561expr_fn_vec!(to_timestamp_seconds);
562expr_fn_vec!(to_unixtime);
563expr_fn!(current_date);
564expr_fn!(current_time);
565expr_fn!(date_part, part date);
566expr_fn!(date_trunc, part date);
567expr_fn!(date_bin, stride source origin);
568expr_fn!(make_date, year month day);
569
570expr_fn!(translate, string from to, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.");
571expr_fn_vec!(trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string.");
572expr_fn_vec!(trunc);
573expr_fn!(upper, arg1, "Converts the string to all upper case.");
574expr_fn!(uuid);
575expr_fn_vec!(r#struct); // Use raw identifier since struct is a keyword
576expr_fn_vec!(named_struct);
577expr_fn!(from_unixtime, unixtime);
578expr_fn!(arrow_typeof, arg_1);
579expr_fn!(arrow_cast, arg_1 datatype);
580expr_fn!(random);
581
582// Array Functions
583array_fn!(array_append, array element);
584array_fn!(array_to_string, array delimiter);
585array_fn!(array_dims, array);
586array_fn!(array_distinct, array);
587array_fn!(array_element, array element);
588array_fn!(array_empty, array);
589array_fn!(array_length, array);
590array_fn!(array_has, first_array second_array);
591array_fn!(array_has_all, first_array second_array);
592array_fn!(array_has_any, first_array second_array);
593array_fn!(array_positions, array element);
594array_fn!(array_ndims, array);
595array_fn!(array_prepend, element array);
596array_fn!(array_pop_back, array);
597array_fn!(array_pop_front, array);
598array_fn!(array_remove, array element);
599array_fn!(array_remove_n, array element max);
600array_fn!(array_remove_all, array element);
601array_fn!(array_repeat, element count);
602array_fn!(array_replace, array from to);
603array_fn!(array_replace_n, array from to max);
604array_fn!(array_replace_all, array from to);
605array_fn!(array_sort, array desc null_first);
606array_fn!(array_intersect, first_array second_array);
607array_fn!(array_union, array1 array2);
608array_fn!(array_except, first_array second_array);
609array_fn!(array_resize, array size value);
610array_fn!(cardinality, array);
611array_fn!(flatten, array);
612array_fn!(range, start stop step);
613
614aggregate_function!(array_agg);
615aggregate_function!(max);
616aggregate_function!(min);
617aggregate_function!(avg);
618aggregate_function!(sum);
619aggregate_function!(bit_and);
620aggregate_function!(bit_or);
621aggregate_function!(bit_xor);
622aggregate_function!(bool_and);
623aggregate_function!(bool_or);
624aggregate_function!(corr, y x);
625aggregate_function!(count);
626aggregate_function!(covar_samp, y x);
627aggregate_function!(covar_pop, y x);
628aggregate_function!(median);
629aggregate_function!(regr_slope, y x);
630aggregate_function!(regr_intercept, y x);
631aggregate_function!(regr_count, y x);
632aggregate_function!(regr_r2, y x);
633aggregate_function!(regr_avgx, y x);
634aggregate_function!(regr_avgy, y x);
635aggregate_function!(regr_sxx, y x);
636aggregate_function!(regr_syy, y x);
637aggregate_function!(regr_sxy, y x);
638aggregate_function!(stddev);
639aggregate_function!(stddev_pop);
640aggregate_function!(var_sample);
641aggregate_function!(var_pop);
642aggregate_function!(approx_distinct);
643aggregate_function!(approx_median);
644
645// Code is commented out since grouping is not yet implemented
646// https://github.com/apache/datafusion-python/issues/861
647// aggregate_function!(grouping);
648
649#[pyfunction]
650#[pyo3(signature = (expression, percentile, num_centroids=None, filter=None))]
651pub fn approx_percentile_cont(
652    expression: PyExpr,
653    percentile: f64,
654    num_centroids: Option<i64>, // enforces optional arguments at the end, currently
655    filter: Option<PyExpr>,
656) -> PyDataFusionResult<PyExpr> {
657    let args = if let Some(num_centroids) = num_centroids {
658        vec![expression.expr, lit(percentile), lit(num_centroids)]
659    } else {
660        vec![expression.expr, lit(percentile)]
661    };
662    let udaf = functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf();
663    let agg_fn = udaf.call(args);
664
665    add_builder_fns_to_aggregate(agg_fn, None, filter, None, None)
666}
667
668#[pyfunction]
669#[pyo3(signature = (expression, weight, percentile, filter=None))]
670pub fn approx_percentile_cont_with_weight(
671    expression: PyExpr,
672    weight: PyExpr,
673    percentile: f64,
674    filter: Option<PyExpr>,
675) -> PyDataFusionResult<PyExpr> {
676    let agg_fn = functions_aggregate::expr_fn::approx_percentile_cont_with_weight(
677        expression.expr,
678        weight.expr,
679        lit(percentile),
680    );
681
682    add_builder_fns_to_aggregate(agg_fn, None, filter, None, None)
683}
684
685// We handle first_value explicitly because the signature expects an order_by
686// https://github.com/apache/datafusion/issues/12376
687#[pyfunction]
688#[pyo3(signature = (expr, distinct=None, filter=None, order_by=None, null_treatment=None))]
689pub fn last_value(
690    expr: PyExpr,
691    distinct: Option<bool>,
692    filter: Option<PyExpr>,
693    order_by: Option<Vec<PySortExpr>>,
694    null_treatment: Option<NullTreatment>,
695) -> PyDataFusionResult<PyExpr> {
696    // If we initialize the UDAF with order_by directly, then it gets over-written by the builder
697    let agg_fn = functions_aggregate::expr_fn::last_value(expr.expr, None);
698
699    add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
700}
701// We handle first_value explicitly because the signature expects an order_by
702// https://github.com/apache/datafusion/issues/12376
703#[pyfunction]
704#[pyo3(signature = (expr, distinct=None, filter=None, order_by=None, null_treatment=None))]
705pub fn first_value(
706    expr: PyExpr,
707    distinct: Option<bool>,
708    filter: Option<PyExpr>,
709    order_by: Option<Vec<PySortExpr>>,
710    null_treatment: Option<NullTreatment>,
711) -> PyDataFusionResult<PyExpr> {
712    // If we initialize the UDAF with order_by directly, then it gets over-written by the builder
713    let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None);
714
715    add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
716}
717
718// nth_value requires a non-expr argument
719#[pyfunction]
720#[pyo3(signature = (expr, n, distinct=None, filter=None, order_by=None, null_treatment=None))]
721pub fn nth_value(
722    expr: PyExpr,
723    n: i64,
724    distinct: Option<bool>,
725    filter: Option<PyExpr>,
726    order_by: Option<Vec<PySortExpr>>,
727    null_treatment: Option<NullTreatment>,
728) -> PyDataFusionResult<PyExpr> {
729    let agg_fn = datafusion::functions_aggregate::nth_value::nth_value(expr.expr, n, vec![]);
730    add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
731}
732
733// string_agg requires a non-expr argument
734#[pyfunction]
735#[pyo3(signature = (expr, delimiter, distinct=None, filter=None, order_by=None, null_treatment=None))]
736pub fn string_agg(
737    expr: PyExpr,
738    delimiter: String,
739    distinct: Option<bool>,
740    filter: Option<PyExpr>,
741    order_by: Option<Vec<PySortExpr>>,
742    null_treatment: Option<NullTreatment>,
743) -> PyDataFusionResult<PyExpr> {
744    let agg_fn = datafusion::functions_aggregate::string_agg::string_agg(expr.expr, lit(delimiter));
745    add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
746}
747
748pub(crate) fn add_builder_fns_to_window(
749    window_fn: Expr,
750    partition_by: Option<Vec<PyExpr>>,
751    window_frame: Option<PyWindowFrame>,
752    order_by: Option<Vec<PySortExpr>>,
753    null_treatment: Option<NullTreatment>,
754) -> PyDataFusionResult<PyExpr> {
755    let null_treatment = null_treatment.map(|n| n.into());
756    let mut builder = window_fn.null_treatment(null_treatment);
757
758    if let Some(partition_cols) = partition_by {
759        builder = builder.partition_by(
760            partition_cols
761                .into_iter()
762                .map(|col| col.clone().into())
763                .collect(),
764        );
765    }
766
767    if let Some(order_by_cols) = order_by {
768        let order_by_cols = to_sort_expressions(order_by_cols);
769        builder = builder.order_by(order_by_cols);
770    }
771
772    if let Some(window_frame) = window_frame {
773        builder = builder.window_frame(window_frame.into());
774    }
775
776    Ok(builder.build().map(|e| e.into())?)
777}
778
779#[pyfunction]
780#[pyo3(signature = (arg, shift_offset, default_value=None, partition_by=None, order_by=None))]
781pub fn lead(
782    arg: PyExpr,
783    shift_offset: i64,
784    default_value: Option<PyScalarValue>,
785    partition_by: Option<Vec<PyExpr>>,
786    order_by: Option<Vec<PySortExpr>>,
787) -> PyDataFusionResult<PyExpr> {
788    let default_value = default_value.map(|v| v.into());
789    let window_fn = functions_window::expr_fn::lead(arg.expr, Some(shift_offset), default_value);
790
791    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
792}
793
794#[pyfunction]
795#[pyo3(signature = (arg, shift_offset, default_value=None, partition_by=None, order_by=None))]
796pub fn lag(
797    arg: PyExpr,
798    shift_offset: i64,
799    default_value: Option<PyScalarValue>,
800    partition_by: Option<Vec<PyExpr>>,
801    order_by: Option<Vec<PySortExpr>>,
802) -> PyDataFusionResult<PyExpr> {
803    let default_value = default_value.map(|v| v.into());
804    let window_fn = functions_window::expr_fn::lag(arg.expr, Some(shift_offset), default_value);
805
806    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
807}
808
809#[pyfunction]
810#[pyo3(signature = (partition_by=None, order_by=None))]
811pub fn row_number(
812    partition_by: Option<Vec<PyExpr>>,
813    order_by: Option<Vec<PySortExpr>>,
814) -> PyDataFusionResult<PyExpr> {
815    let window_fn = functions_window::expr_fn::row_number();
816
817    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
818}
819
820#[pyfunction]
821#[pyo3(signature = (partition_by=None, order_by=None))]
822pub fn rank(
823    partition_by: Option<Vec<PyExpr>>,
824    order_by: Option<Vec<PySortExpr>>,
825) -> PyDataFusionResult<PyExpr> {
826    let window_fn = functions_window::expr_fn::rank();
827
828    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
829}
830
831#[pyfunction]
832#[pyo3(signature = (partition_by=None, order_by=None))]
833pub fn dense_rank(
834    partition_by: Option<Vec<PyExpr>>,
835    order_by: Option<Vec<PySortExpr>>,
836) -> PyDataFusionResult<PyExpr> {
837    let window_fn = functions_window::expr_fn::dense_rank();
838
839    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
840}
841
842#[pyfunction]
843#[pyo3(signature = (partition_by=None, order_by=None))]
844pub fn percent_rank(
845    partition_by: Option<Vec<PyExpr>>,
846    order_by: Option<Vec<PySortExpr>>,
847) -> PyDataFusionResult<PyExpr> {
848    let window_fn = functions_window::expr_fn::percent_rank();
849
850    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
851}
852
853#[pyfunction]
854#[pyo3(signature = (partition_by=None, order_by=None))]
855pub fn cume_dist(
856    partition_by: Option<Vec<PyExpr>>,
857    order_by: Option<Vec<PySortExpr>>,
858) -> PyDataFusionResult<PyExpr> {
859    let window_fn = functions_window::expr_fn::cume_dist();
860
861    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
862}
863
864#[pyfunction]
865#[pyo3(signature = (arg, partition_by=None, order_by=None))]
866pub fn ntile(
867    arg: PyExpr,
868    partition_by: Option<Vec<PyExpr>>,
869    order_by: Option<Vec<PySortExpr>>,
870) -> PyDataFusionResult<PyExpr> {
871    let window_fn = functions_window::expr_fn::ntile(arg.into());
872
873    add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
874}
875
876pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
877    m.add_wrapped(wrap_pyfunction!(abs))?;
878    m.add_wrapped(wrap_pyfunction!(acos))?;
879    m.add_wrapped(wrap_pyfunction!(acosh))?;
880    m.add_wrapped(wrap_pyfunction!(approx_distinct))?;
881    m.add_wrapped(wrap_pyfunction!(alias))?;
882    m.add_wrapped(wrap_pyfunction!(approx_median))?;
883    m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?;
884    m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?;
885    m.add_wrapped(wrap_pyfunction!(range))?;
886    m.add_wrapped(wrap_pyfunction!(array_agg))?;
887    m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
888    m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
889    m.add_wrapped(wrap_pyfunction!(ascii))?;
890    m.add_wrapped(wrap_pyfunction!(asin))?;
891    m.add_wrapped(wrap_pyfunction!(asinh))?;
892    m.add_wrapped(wrap_pyfunction!(atan))?;
893    m.add_wrapped(wrap_pyfunction!(atanh))?;
894    m.add_wrapped(wrap_pyfunction!(atan2))?;
895    m.add_wrapped(wrap_pyfunction!(avg))?;
896    m.add_wrapped(wrap_pyfunction!(bit_length))?;
897    m.add_wrapped(wrap_pyfunction!(btrim))?;
898    m.add_wrapped(wrap_pyfunction!(cbrt))?;
899    m.add_wrapped(wrap_pyfunction!(ceil))?;
900    m.add_wrapped(wrap_pyfunction!(character_length))?;
901    m.add_wrapped(wrap_pyfunction!(chr))?;
902    m.add_wrapped(wrap_pyfunction!(char_length))?;
903    m.add_wrapped(wrap_pyfunction!(coalesce))?;
904    m.add_wrapped(wrap_pyfunction!(case))?;
905    m.add_wrapped(wrap_pyfunction!(when))?;
906    m.add_wrapped(wrap_pyfunction!(col))?;
907    m.add_wrapped(wrap_pyfunction!(concat_ws))?;
908    m.add_wrapped(wrap_pyfunction!(concat))?;
909    m.add_wrapped(wrap_pyfunction!(corr))?;
910    m.add_wrapped(wrap_pyfunction!(cos))?;
911    m.add_wrapped(wrap_pyfunction!(cosh))?;
912    m.add_wrapped(wrap_pyfunction!(cot))?;
913    m.add_wrapped(wrap_pyfunction!(count))?;
914    m.add_wrapped(wrap_pyfunction!(covar_pop))?;
915    m.add_wrapped(wrap_pyfunction!(covar_samp))?;
916    m.add_wrapped(wrap_pyfunction!(current_date))?;
917    m.add_wrapped(wrap_pyfunction!(current_time))?;
918    m.add_wrapped(wrap_pyfunction!(degrees))?;
919    m.add_wrapped(wrap_pyfunction!(date_bin))?;
920    m.add_wrapped(wrap_pyfunction!(date_part))?;
921    m.add_wrapped(wrap_pyfunction!(date_trunc))?;
922    m.add_wrapped(wrap_pyfunction!(make_date))?;
923    m.add_wrapped(wrap_pyfunction!(digest))?;
924    m.add_wrapped(wrap_pyfunction!(ends_with))?;
925    m.add_wrapped(wrap_pyfunction!(exp))?;
926    m.add_wrapped(wrap_pyfunction!(factorial))?;
927    m.add_wrapped(wrap_pyfunction!(floor))?;
928    m.add_wrapped(wrap_pyfunction!(from_unixtime))?;
929    m.add_wrapped(wrap_pyfunction!(gcd))?;
930    // m.add_wrapped(wrap_pyfunction!(grouping))?;
931    m.add_wrapped(wrap_pyfunction!(in_list))?;
932    m.add_wrapped(wrap_pyfunction!(initcap))?;
933    m.add_wrapped(wrap_pyfunction!(isnan))?;
934    m.add_wrapped(wrap_pyfunction!(iszero))?;
935    m.add_wrapped(wrap_pyfunction!(levenshtein))?;
936    m.add_wrapped(wrap_pyfunction!(lcm))?;
937    m.add_wrapped(wrap_pyfunction!(left))?;
938    m.add_wrapped(wrap_pyfunction!(length))?;
939    m.add_wrapped(wrap_pyfunction!(ln))?;
940    m.add_wrapped(wrap_pyfunction!(log))?;
941    m.add_wrapped(wrap_pyfunction!(log10))?;
942    m.add_wrapped(wrap_pyfunction!(log2))?;
943    m.add_wrapped(wrap_pyfunction!(lower))?;
944    m.add_wrapped(wrap_pyfunction!(lpad))?;
945    m.add_wrapped(wrap_pyfunction!(ltrim))?;
946    m.add_wrapped(wrap_pyfunction!(max))?;
947    m.add_wrapped(wrap_pyfunction!(make_array))?;
948    m.add_wrapped(wrap_pyfunction!(md5))?;
949    m.add_wrapped(wrap_pyfunction!(median))?;
950    m.add_wrapped(wrap_pyfunction!(min))?;
951    m.add_wrapped(wrap_pyfunction!(named_struct))?;
952    m.add_wrapped(wrap_pyfunction!(nanvl))?;
953    m.add_wrapped(wrap_pyfunction!(nvl))?;
954    m.add_wrapped(wrap_pyfunction!(now))?;
955    m.add_wrapped(wrap_pyfunction!(nullif))?;
956    m.add_wrapped(wrap_pyfunction!(octet_length))?;
957    m.add_wrapped(wrap_pyfunction!(order_by))?;
958    m.add_wrapped(wrap_pyfunction!(overlay))?;
959    m.add_wrapped(wrap_pyfunction!(pi))?;
960    m.add_wrapped(wrap_pyfunction!(power))?;
961    m.add_wrapped(wrap_pyfunction!(radians))?;
962    m.add_wrapped(wrap_pyfunction!(random))?;
963    m.add_wrapped(wrap_pyfunction!(regexp_count))?;
964    m.add_wrapped(wrap_pyfunction!(regexp_like))?;
965    m.add_wrapped(wrap_pyfunction!(regexp_match))?;
966    m.add_wrapped(wrap_pyfunction!(regexp_replace))?;
967    m.add_wrapped(wrap_pyfunction!(repeat))?;
968    m.add_wrapped(wrap_pyfunction!(replace))?;
969    m.add_wrapped(wrap_pyfunction!(reverse))?;
970    m.add_wrapped(wrap_pyfunction!(right))?;
971    m.add_wrapped(wrap_pyfunction!(round))?;
972    m.add_wrapped(wrap_pyfunction!(rpad))?;
973    m.add_wrapped(wrap_pyfunction!(rtrim))?;
974    m.add_wrapped(wrap_pyfunction!(sha224))?;
975    m.add_wrapped(wrap_pyfunction!(sha256))?;
976    m.add_wrapped(wrap_pyfunction!(sha384))?;
977    m.add_wrapped(wrap_pyfunction!(sha512))?;
978    m.add_wrapped(wrap_pyfunction!(signum))?;
979    m.add_wrapped(wrap_pyfunction!(sin))?;
980    m.add_wrapped(wrap_pyfunction!(sinh))?;
981    m.add_wrapped(wrap_pyfunction!(split_part))?;
982    m.add_wrapped(wrap_pyfunction!(sqrt))?;
983    m.add_wrapped(wrap_pyfunction!(starts_with))?;
984    m.add_wrapped(wrap_pyfunction!(stddev))?;
985    m.add_wrapped(wrap_pyfunction!(stddev_pop))?;
986    m.add_wrapped(wrap_pyfunction!(string_agg))?;
987    m.add_wrapped(wrap_pyfunction!(strpos))?;
988    m.add_wrapped(wrap_pyfunction!(r#struct))?; // Use raw identifier since struct is a keyword
989    m.add_wrapped(wrap_pyfunction!(substr))?;
990    m.add_wrapped(wrap_pyfunction!(substr_index))?;
991    m.add_wrapped(wrap_pyfunction!(substring))?;
992    m.add_wrapped(wrap_pyfunction!(find_in_set))?;
993    m.add_wrapped(wrap_pyfunction!(sum))?;
994    m.add_wrapped(wrap_pyfunction!(tan))?;
995    m.add_wrapped(wrap_pyfunction!(tanh))?;
996    m.add_wrapped(wrap_pyfunction!(to_hex))?;
997    m.add_wrapped(wrap_pyfunction!(to_timestamp))?;
998    m.add_wrapped(wrap_pyfunction!(to_timestamp_millis))?;
999    m.add_wrapped(wrap_pyfunction!(to_timestamp_nanos))?;
1000    m.add_wrapped(wrap_pyfunction!(to_timestamp_micros))?;
1001    m.add_wrapped(wrap_pyfunction!(to_timestamp_seconds))?;
1002    m.add_wrapped(wrap_pyfunction!(to_unixtime))?;
1003    m.add_wrapped(wrap_pyfunction!(translate))?;
1004    m.add_wrapped(wrap_pyfunction!(trim))?;
1005    m.add_wrapped(wrap_pyfunction!(trunc))?;
1006    m.add_wrapped(wrap_pyfunction!(upper))?;
1007    m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision
1008    m.add_wrapped(wrap_pyfunction!(var_pop))?;
1009    m.add_wrapped(wrap_pyfunction!(var_sample))?;
1010    m.add_wrapped(wrap_pyfunction!(window))?;
1011    m.add_wrapped(wrap_pyfunction!(regr_avgx))?;
1012    m.add_wrapped(wrap_pyfunction!(regr_avgy))?;
1013    m.add_wrapped(wrap_pyfunction!(regr_count))?;
1014    m.add_wrapped(wrap_pyfunction!(regr_intercept))?;
1015    m.add_wrapped(wrap_pyfunction!(regr_r2))?;
1016    m.add_wrapped(wrap_pyfunction!(regr_slope))?;
1017    m.add_wrapped(wrap_pyfunction!(regr_sxx))?;
1018    m.add_wrapped(wrap_pyfunction!(regr_sxy))?;
1019    m.add_wrapped(wrap_pyfunction!(regr_syy))?;
1020    m.add_wrapped(wrap_pyfunction!(first_value))?;
1021    m.add_wrapped(wrap_pyfunction!(last_value))?;
1022    m.add_wrapped(wrap_pyfunction!(nth_value))?;
1023    m.add_wrapped(wrap_pyfunction!(bit_and))?;
1024    m.add_wrapped(wrap_pyfunction!(bit_or))?;
1025    m.add_wrapped(wrap_pyfunction!(bit_xor))?;
1026    m.add_wrapped(wrap_pyfunction!(bool_and))?;
1027    m.add_wrapped(wrap_pyfunction!(bool_or))?;
1028
1029    //Binary String Functions
1030    m.add_wrapped(wrap_pyfunction!(encode))?;
1031    m.add_wrapped(wrap_pyfunction!(decode))?;
1032
1033    // Array Functions
1034    m.add_wrapped(wrap_pyfunction!(array_append))?;
1035    m.add_wrapped(wrap_pyfunction!(array_concat))?;
1036    m.add_wrapped(wrap_pyfunction!(array_cat))?;
1037    m.add_wrapped(wrap_pyfunction!(array_dims))?;
1038    m.add_wrapped(wrap_pyfunction!(array_distinct))?;
1039    m.add_wrapped(wrap_pyfunction!(array_element))?;
1040    m.add_wrapped(wrap_pyfunction!(array_empty))?;
1041    m.add_wrapped(wrap_pyfunction!(array_length))?;
1042    m.add_wrapped(wrap_pyfunction!(array_has))?;
1043    m.add_wrapped(wrap_pyfunction!(array_has_all))?;
1044    m.add_wrapped(wrap_pyfunction!(array_has_any))?;
1045    m.add_wrapped(wrap_pyfunction!(array_position))?;
1046    m.add_wrapped(wrap_pyfunction!(array_positions))?;
1047    m.add_wrapped(wrap_pyfunction!(array_to_string))?;
1048    m.add_wrapped(wrap_pyfunction!(array_intersect))?;
1049    m.add_wrapped(wrap_pyfunction!(array_union))?;
1050    m.add_wrapped(wrap_pyfunction!(array_except))?;
1051    m.add_wrapped(wrap_pyfunction!(array_resize))?;
1052    m.add_wrapped(wrap_pyfunction!(array_ndims))?;
1053    m.add_wrapped(wrap_pyfunction!(array_prepend))?;
1054    m.add_wrapped(wrap_pyfunction!(array_pop_back))?;
1055    m.add_wrapped(wrap_pyfunction!(array_pop_front))?;
1056    m.add_wrapped(wrap_pyfunction!(array_remove))?;
1057    m.add_wrapped(wrap_pyfunction!(array_remove_n))?;
1058    m.add_wrapped(wrap_pyfunction!(array_remove_all))?;
1059    m.add_wrapped(wrap_pyfunction!(array_repeat))?;
1060    m.add_wrapped(wrap_pyfunction!(array_replace))?;
1061    m.add_wrapped(wrap_pyfunction!(array_replace_n))?;
1062    m.add_wrapped(wrap_pyfunction!(array_replace_all))?;
1063    m.add_wrapped(wrap_pyfunction!(array_sort))?;
1064    m.add_wrapped(wrap_pyfunction!(array_slice))?;
1065    m.add_wrapped(wrap_pyfunction!(flatten))?;
1066    m.add_wrapped(wrap_pyfunction!(cardinality))?;
1067
1068    // Window Functions
1069    m.add_wrapped(wrap_pyfunction!(lead))?;
1070    m.add_wrapped(wrap_pyfunction!(lag))?;
1071    m.add_wrapped(wrap_pyfunction!(rank))?;
1072    m.add_wrapped(wrap_pyfunction!(row_number))?;
1073    m.add_wrapped(wrap_pyfunction!(dense_rank))?;
1074    m.add_wrapped(wrap_pyfunction!(percent_rank))?;
1075    m.add_wrapped(wrap_pyfunction!(cume_dist))?;
1076    m.add_wrapped(wrap_pyfunction!(ntile))?;
1077
1078    Ok(())
1079}