1use std::collections::HashMap;
19
20use datafusion::functions_aggregate::all_default_aggregate_functions;
21use datafusion::functions_window::all_default_window_functions;
22use datafusion::logical_expr::expr::WindowFunctionParams;
23use datafusion::logical_expr::ExprFunctionExt;
24use datafusion::logical_expr::WindowFrame;
25use pyo3::{prelude::*, wrap_pyfunction};
26
27use crate::common::data_type::NullTreatment;
28use crate::common::data_type::PyScalarValue;
29use crate::context::PySessionContext;
30use crate::errors::PyDataFusionError;
31use crate::errors::PyDataFusionResult;
32use crate::expr::conditional_expr::PyCaseBuilder;
33use crate::expr::sort_expr::to_sort_expressions;
34use crate::expr::sort_expr::PySortExpr;
35use crate::expr::window::PyWindowFrame;
36use crate::expr::PyExpr;
37use datafusion::common::{Column, ScalarValue, TableReference};
38use datafusion::execution::FunctionRegistry;
39use datafusion::functions;
40use datafusion::functions_aggregate;
41use datafusion::functions_window;
42use datafusion::logical_expr::expr::Alias;
43use datafusion::logical_expr::sqlparser::ast::NullTreatment as DFNullTreatment;
44use datafusion::logical_expr::{expr::WindowFunction, lit, Expr, WindowFunctionDefinition};
45
46fn add_builder_fns_to_aggregate(
47 agg_fn: Expr,
48 distinct: Option<bool>,
49 filter: Option<PyExpr>,
50 order_by: Option<Vec<PySortExpr>>,
51 null_treatment: Option<NullTreatment>,
52) -> PyDataFusionResult<PyExpr> {
53 let mut builder = agg_fn.null_treatment(None);
56
57 if let Some(order_by_cols) = order_by {
58 let order_by_cols = to_sort_expressions(order_by_cols);
59 builder = builder.order_by(order_by_cols);
60 }
61
62 if let Some(true) = distinct {
63 builder = builder.distinct();
64 }
65
66 if let Some(filter) = filter {
67 builder = builder.filter(filter.expr);
68 }
69
70 builder = builder.null_treatment(null_treatment.map(DFNullTreatment::from));
71
72 Ok(builder.build()?.into())
73}
74
75#[pyfunction]
76fn in_list(expr: PyExpr, value: Vec<PyExpr>, negated: bool) -> PyExpr {
77 datafusion::logical_expr::in_list(
78 expr.expr,
79 value.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
80 negated,
81 )
82 .into()
83}
84
85#[pyfunction]
86fn make_array(exprs: Vec<PyExpr>) -> PyExpr {
87 datafusion::functions_nested::expr_fn::make_array(exprs.into_iter().map(|x| x.into()).collect())
88 .into()
89}
90
91#[pyfunction]
92fn array_concat(exprs: Vec<PyExpr>) -> PyExpr {
93 let exprs = exprs.into_iter().map(|x| x.into()).collect();
94 datafusion::functions_nested::expr_fn::array_concat(exprs).into()
95}
96
97#[pyfunction]
98fn array_cat(exprs: Vec<PyExpr>) -> PyExpr {
99 array_concat(exprs)
100}
101
102#[pyfunction]
103#[pyo3(signature = (array, element, index=None))]
104fn array_position(array: PyExpr, element: PyExpr, index: Option<i64>) -> PyExpr {
105 let index = ScalarValue::Int64(index);
106 let index = Expr::Literal(index);
107 datafusion::functions_nested::expr_fn::array_position(array.into(), element.into(), index)
108 .into()
109}
110
111#[pyfunction]
112#[pyo3(signature = (array, begin, end, stride=None))]
113fn array_slice(array: PyExpr, begin: PyExpr, end: PyExpr, stride: Option<PyExpr>) -> PyExpr {
114 datafusion::functions_nested::expr_fn::array_slice(
115 array.into(),
116 begin.into(),
117 end.into(),
118 stride.map(Into::into),
119 )
120 .into()
121}
122
123#[pyfunction]
127fn digest(value: PyExpr, method: PyExpr) -> PyExpr {
128 PyExpr {
129 expr: functions::expr_fn::digest(value.expr, method.expr),
130 }
131}
132
133#[pyfunction]
136fn concat(args: Vec<PyExpr>) -> PyResult<PyExpr> {
137 let args = args.into_iter().map(|e| e.expr).collect::<Vec<_>>();
138 Ok(functions::string::expr_fn::concat(args).into())
139}
140
141#[pyfunction]
145fn concat_ws(sep: String, args: Vec<PyExpr>) -> PyResult<PyExpr> {
146 let args = args.into_iter().map(|e| e.expr).collect::<Vec<_>>();
147 Ok(functions::string::expr_fn::concat_ws(lit(sep), args).into())
148}
149
150#[pyfunction]
151#[pyo3(signature = (values, regex, flags=None))]
152fn regexp_like(values: PyExpr, regex: PyExpr, flags: Option<PyExpr>) -> PyResult<PyExpr> {
153 Ok(functions::expr_fn::regexp_like(values.expr, regex.expr, flags.map(|x| x.expr)).into())
154}
155
156#[pyfunction]
157#[pyo3(signature = (values, regex, flags=None))]
158fn regexp_match(values: PyExpr, regex: PyExpr, flags: Option<PyExpr>) -> PyResult<PyExpr> {
159 Ok(functions::expr_fn::regexp_match(values.expr, regex.expr, flags.map(|x| x.expr)).into())
160}
161
162#[pyfunction]
163#[pyo3(signature = (string, pattern, replacement, flags=None))]
164fn regexp_replace(
166 string: PyExpr,
167 pattern: PyExpr,
168 replacement: PyExpr,
169 flags: Option<PyExpr>,
170) -> PyResult<PyExpr> {
171 Ok(functions::expr_fn::regexp_replace(
172 string.into(),
173 pattern.into(),
174 replacement.into(),
175 flags.map(|x| x.expr),
176 )
177 .into())
178}
179
180#[pyfunction]
181#[pyo3(signature = (string, pattern, start, flags=None))]
182fn regexp_count(
184 string: PyExpr,
185 pattern: PyExpr,
186 start: Option<PyExpr>,
187 flags: Option<PyExpr>,
188) -> PyResult<PyExpr> {
189 Ok(functions::expr_fn::regexp_count(
190 string.expr,
191 pattern.expr,
192 start.map(|x| x.expr),
193 flags.map(|x| x.expr),
194 )
195 .into())
196}
197
198#[pyfunction]
200fn order_by(expr: PyExpr, asc: bool, nulls_first: bool) -> PyResult<PySortExpr> {
201 Ok(PySortExpr::from(datafusion::logical_expr::expr::Sort {
202 expr: expr.expr,
203 asc,
204 nulls_first,
205 }))
206}
207
208#[pyfunction]
210#[pyo3(signature = (expr, name, metadata=None))]
211fn alias(expr: PyExpr, name: &str, metadata: Option<HashMap<String, String>>) -> PyResult<PyExpr> {
212 let relation: Option<TableReference> = None;
213 Ok(PyExpr {
214 expr: datafusion::logical_expr::Expr::Alias(
215 Alias::new(expr.expr, relation, name).with_metadata(metadata),
216 ),
217 })
218}
219
220#[pyfunction]
222fn col(name: &str) -> PyResult<PyExpr> {
223 Ok(PyExpr {
224 expr: datafusion::logical_expr::Expr::Column(Column::new_unqualified(name)),
225 })
226}
227
228#[pyfunction]
230fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
231 Ok(PyCaseBuilder {
232 case_builder: datafusion::logical_expr::case(expr.expr),
233 })
234}
235
236#[pyfunction]
238fn when(when: PyExpr, then: PyExpr) -> PyResult<PyCaseBuilder> {
239 Ok(PyCaseBuilder {
240 case_builder: datafusion::logical_expr::when(when.expr, then.expr),
241 })
242}
243
244fn find_window_fn(
256 name: &str,
257 ctx: Option<PySessionContext>,
258) -> PyDataFusionResult<WindowFunctionDefinition> {
259 if let Some(ctx) = ctx {
260 let udaf = ctx
262 .ctx
263 .udaf(name)
264 .map(WindowFunctionDefinition::AggregateUDF)
265 .ok();
266
267 if let Some(udaf) = udaf {
268 return Ok(udaf);
269 }
270
271 let session_state = ctx.ctx.state();
272
273 let window_fn = session_state
275 .window_functions()
276 .get(name)
277 .map(|f| WindowFunctionDefinition::WindowUDF(f.clone()));
278
279 if let Some(window_fn) = window_fn {
280 return Ok(window_fn);
281 }
282
283 let agg_fn = session_state
285 .aggregate_functions()
286 .get(name)
287 .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone()));
288
289 if let Some(agg_fn) = agg_fn {
290 return Ok(agg_fn);
291 }
292 }
293
294 let agg_fn = all_default_aggregate_functions()
296 .iter()
297 .find(|v| v.name() == name || v.aliases().contains(&name.to_string()))
298 .map(|f| WindowFunctionDefinition::AggregateUDF(f.clone()));
299
300 if let Some(agg_fn) = agg_fn {
301 return Ok(agg_fn);
302 }
303
304 let window_fn = all_default_window_functions()
306 .iter()
307 .find(|v| v.name() == name || v.aliases().contains(&name.to_string()))
308 .map(|f| WindowFunctionDefinition::WindowUDF(f.clone()));
309
310 if let Some(window_fn) = window_fn {
311 return Ok(window_fn);
312 }
313
314 Err(PyDataFusionError::Common(format!(
315 "window function `{name}` not found"
316 )))
317}
318
319#[pyfunction]
321#[pyo3(signature = (name, args, partition_by=None, order_by=None, window_frame=None, ctx=None))]
322fn window(
323 name: &str,
324 args: Vec<PyExpr>,
325 partition_by: Option<Vec<PyExpr>>,
326 order_by: Option<Vec<PySortExpr>>,
327 window_frame: Option<PyWindowFrame>,
328 ctx: Option<PySessionContext>,
329) -> PyResult<PyExpr> {
330 let fun = find_window_fn(name, ctx)?;
331
332 let window_frame = window_frame
333 .map(|w| w.into())
334 .unwrap_or(WindowFrame::new(order_by.as_ref().map(|v| !v.is_empty())));
335
336 Ok(PyExpr {
337 expr: datafusion::logical_expr::Expr::WindowFunction(WindowFunction {
338 fun,
339 params: WindowFunctionParams {
340 args: args.into_iter().map(|x| x.expr).collect::<Vec<_>>(),
341 partition_by: partition_by
342 .unwrap_or_default()
343 .into_iter()
344 .map(|x| x.expr)
345 .collect::<Vec<_>>(),
346 order_by: order_by
347 .unwrap_or_default()
348 .into_iter()
349 .map(|x| x.into())
350 .collect::<Vec<_>>(),
351 window_frame,
352 null_treatment: None,
353 },
354 }),
355 })
356}
357
/// Generates a `#[pyfunction]` wrapper around `functions_aggregate::expr_fn::$name`.
///
/// * `aggregate_function!(name)` expands to a wrapper taking a single argument named `expr`.
/// * `aggregate_function!(name, a b)` expands to a wrapper taking the listed arguments
///   (space-separated identifiers, each typed `PyExpr`).
///
/// Every wrapper also accepts the optional `distinct`, `filter`, `order_by` and
/// `null_treatment` keywords and applies them through `add_builder_fns_to_aggregate`.
macro_rules! aggregate_function {
    ($name:ident) => {
        aggregate_function!($name, expr);
    };
    ($name:ident, $($param:ident)*) => {
        #[pyfunction]
        #[pyo3(signature = ($($param),*, distinct=None, filter=None, order_by=None, null_treatment=None))]
        fn $name(
            $($param: PyExpr),*,
            distinct: Option<bool>,
            filter: Option<PyExpr>,
            order_by: Option<Vec<PySortExpr>>,
            null_treatment: Option<NullTreatment>
        ) -> PyDataFusionResult<PyExpr> {
            let aggregate = functions_aggregate::expr_fn::$name($($param.into()),*);

            add_builder_fns_to_aggregate(aggregate, distinct, filter, order_by, null_treatment)
        }
    };
}
382
/// Generates a `#[pyfunction]` that forwards to `functions::expr_fn::$func`.
///
/// Arguments are written as space-separated identifiers and are each typed `PyExpr`.
/// Accepted forms:
/// * `expr_fn!(name)` - no arguments; the doc string is the function name;
/// * `expr_fn!(name, a b)` - arguments; the doc string is the function name;
/// * `expr_fn!(name, "doc")` - no arguments; explicit doc string;
/// * `expr_fn!(name, a b, "doc")` - arguments and explicit doc string.
///
/// The first three arms only fill in defaults and re-invoke the last arm, so arm order
/// matters and must be preserved.
macro_rules! expr_fn {
    ($func:ident) => {
        expr_fn!($func, , stringify!($func));
    };
    ($func:ident, $($param:ident)*) => {
        expr_fn!($func, $($param)*, stringify!($func));
    };
    ($func:ident, $doc:expr) => {
        expr_fn!($func, , $doc);
    };
    ($func:ident, $($param:ident)*, $doc:expr) => {
        #[doc = $doc]
        #[pyfunction]
        fn $func($($param: PyExpr),*) -> PyExpr {
            functions::expr_fn::$func($($param.into()),*).into()
        }
    };
}
/// Generates a `#[pyfunction]` that forwards a variable number of arguments, as a `Vec`, to
/// `functions::expr_fn::$func`.
///
/// * `expr_fn_vec!(name)` uses the function name as the doc string.
/// * `expr_fn_vec!(name, "doc")` uses the given doc string.
///
/// The generated function is declared with `#[pyo3(signature = (*args))]`, so it receives
/// its inputs as Python positional varargs.
macro_rules! expr_fn_vec {
    ($func:ident) => {
        expr_fn_vec!($func, stringify!($func));
    };
    ($func:ident, $doc:expr) => {
        #[doc = $doc]
        #[pyfunction]
        #[pyo3(signature = (*args))]
        fn $func(args: Vec<PyExpr>) -> PyExpr {
            let exprs: Vec<_> = args.into_iter().map(|arg| arg.into()).collect();
            functions::expr_fn::$func(exprs).into()
        }
    };
}
421
/// Generates a `#[pyfunction]` that forwards to
/// `datafusion::functions_nested::expr_fn::$func`.
///
/// Arguments are written as space-separated identifiers and are each typed `PyExpr`.
/// Accepted forms:
/// * `array_fn!(name)` - no arguments; the doc string is the function name;
/// * `array_fn!(name, a b)` - arguments; the doc string is the function name;
/// * `array_fn!(name, "doc")` - no arguments; explicit doc string;
/// * `array_fn!(name, a b, "doc")` - arguments and explicit doc string.
///
/// The first three arms only fill in defaults and re-invoke the last arm, so arm order
/// matters and must be preserved.
macro_rules! array_fn {
    ($func:ident) => {
        array_fn!($func, , stringify!($func));
    };
    ($func:ident, $($param:ident)*) => {
        array_fn!($func, $($param)*, stringify!($func));
    };
    ($func:ident, $doc:expr) => {
        array_fn!($func, , $doc);
    };
    ($func:ident, $($param:ident)*, $doc:expr) => {
        #[doc = $doc]
        #[pyfunction]
        fn $func($($param: PyExpr),*) -> PyExpr {
            datafusion::functions_nested::expr_fn::$func($($param.into()),*).into()
        }
    };
}
443
444expr_fn!(abs, num);
445expr_fn!(acos, num);
446expr_fn!(acosh, num);
447expr_fn!(ascii, arg1, "Returns the numeric code of the first character of the argument. In UTF8 encoding, returns the Unicode code point of the character. In other multibyte encodings, the argument must be an ASCII character.");
448expr_fn!(asin, num);
449expr_fn!(asinh, num);
450expr_fn!(atan, num);
451expr_fn!(atanh, num);
452expr_fn!(atan2, y x);
453expr_fn!(
454 bit_length,
455 arg,
456 "Returns number of bits in the string (8 times the octet_length)."
457);
458expr_fn_vec!(btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string.");
459expr_fn!(cbrt, num);
460expr_fn!(ceil, num);
461expr_fn!(
462 character_length,
463 string,
464 "Returns number of characters in the string."
465);
466expr_fn!(length, string);
467expr_fn!(char_length, string);
468expr_fn!(chr, arg, "Returns the character with the given code.");
469expr_fn_vec!(coalesce);
470expr_fn!(cos, num);
471expr_fn!(cosh, num);
472expr_fn!(cot, num);
473expr_fn!(degrees, num);
474expr_fn!(decode, input encoding);
475expr_fn!(encode, input encoding);
476expr_fn!(ends_with, string suffix, "Returns true if string ends with suffix.");
477expr_fn!(exp, num);
478expr_fn!(factorial, num);
479expr_fn!(floor, num);
480expr_fn!(gcd, x y);
481expr_fn!(initcap, string, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.");
482expr_fn!(isnan, num);
483expr_fn!(iszero, num);
484expr_fn!(levenshtein, string1 string2);
485expr_fn!(lcm, x y);
486expr_fn!(left, string n, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters.");
487expr_fn!(ln, num);
488expr_fn!(log, base num);
489expr_fn!(log10, num);
490expr_fn!(log2, num);
491expr_fn!(lower, arg1, "Converts the string to all lower case");
492expr_fn_vec!(lpad, "Extends the string to length length by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right).");
493expr_fn_vec!(ltrim, "Removes the longest string containing only characters in characters (a space by default) from the start of string.");
494expr_fn!(
495 md5,
496 input_arg,
497 "Computes the MD5 hash of the argument, with the result written in hexadecimal."
498);
499expr_fn!(
500 nanvl,
501 x y,
502 "Returns x if x is not NaN otherwise returns y."
503);
504expr_fn!(
505 nvl,
506 x y,
507 "Returns x if x is not NULL otherwise returns y."
508);
509expr_fn!(nullif, arg_1 arg_2);
510expr_fn!(octet_length, args, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces.");
511expr_fn_vec!(overlay);
512expr_fn!(pi);
513expr_fn!(power, base exponent);
514expr_fn!(radians, num);
515expr_fn!(repeat, string n, "Repeats string the specified number of times.");
516expr_fn!(
517 replace,
518 string from to,
519 "Replaces all occurrences in string of substring from with substring to."
520);
521expr_fn!(
522 reverse,
523 string,
524 "Reverses the order of the characters in the string."
525);
526expr_fn!(right, string n, "Returns last n characters in the string, or when n is negative, returns all but first |n| characters.");
527expr_fn_vec!(round);
528expr_fn_vec!(rpad, "Extends the string to length length by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.");
529expr_fn_vec!(rtrim, "Removes the longest string containing only characters in characters (a space by default) from the end of string.");
530expr_fn!(sha224, input_arg1);
531expr_fn!(sha256, input_arg1);
532expr_fn!(sha384, input_arg1);
533expr_fn!(sha512, input_arg1);
534expr_fn!(signum, num);
535expr_fn!(sin, num);
536expr_fn!(sinh, num);
537expr_fn!(
538 split_part,
539 string delimiter index,
540 "Splits string at occurrences of delimiter and returns the n'th field (counting from one)."
541);
542expr_fn!(sqrt, num);
543expr_fn!(starts_with, string prefix, "Returns true if string starts with prefix.");
544expr_fn!(strpos, string substring, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)");
545expr_fn!(substr, string position);
546expr_fn!(substr_index, string delimiter count);
547expr_fn!(substring, string position length);
548expr_fn!(find_in_set, string string_list);
549expr_fn!(tan, num);
550expr_fn!(tanh, num);
551expr_fn!(
552 to_hex,
553 arg1,
554 "Converts the number to its equivalent hexadecimal representation."
555);
556expr_fn!(now);
557expr_fn_vec!(to_timestamp);
558expr_fn_vec!(to_timestamp_millis);
559expr_fn_vec!(to_timestamp_nanos);
560expr_fn_vec!(to_timestamp_micros);
561expr_fn_vec!(to_timestamp_seconds);
562expr_fn_vec!(to_unixtime);
563expr_fn!(current_date);
564expr_fn!(current_time);
565expr_fn!(date_part, part date);
566expr_fn!(date_trunc, part date);
567expr_fn!(date_bin, stride source origin);
568expr_fn!(make_date, year month day);
569
570expr_fn!(translate, string from to, "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.");
571expr_fn_vec!(trim, "Removes the longest string containing only characters in characters (a space by default) from the start, end, or both ends (BOTH is the default) of string.");
572expr_fn_vec!(trunc);
573expr_fn!(upper, arg1, "Converts the string to all upper case.");
574expr_fn!(uuid);
575expr_fn_vec!(r#struct); expr_fn_vec!(named_struct);
577expr_fn!(from_unixtime, unixtime);
578expr_fn!(arrow_typeof, arg_1);
579expr_fn!(arrow_cast, arg_1 datatype);
580expr_fn!(random);
581
582array_fn!(array_append, array element);
584array_fn!(array_to_string, array delimiter);
585array_fn!(array_dims, array);
586array_fn!(array_distinct, array);
587array_fn!(array_element, array element);
588array_fn!(array_empty, array);
589array_fn!(array_length, array);
590array_fn!(array_has, first_array second_array);
591array_fn!(array_has_all, first_array second_array);
592array_fn!(array_has_any, first_array second_array);
593array_fn!(array_positions, array element);
594array_fn!(array_ndims, array);
595array_fn!(array_prepend, element array);
596array_fn!(array_pop_back, array);
597array_fn!(array_pop_front, array);
598array_fn!(array_remove, array element);
599array_fn!(array_remove_n, array element max);
600array_fn!(array_remove_all, array element);
601array_fn!(array_repeat, element count);
602array_fn!(array_replace, array from to);
603array_fn!(array_replace_n, array from to max);
604array_fn!(array_replace_all, array from to);
605array_fn!(array_sort, array desc null_first);
606array_fn!(array_intersect, first_array second_array);
607array_fn!(array_union, array1 array2);
608array_fn!(array_except, first_array second_array);
609array_fn!(array_resize, array size value);
610array_fn!(cardinality, array);
611array_fn!(flatten, array);
612array_fn!(range, start stop step);
613
614aggregate_function!(array_agg);
615aggregate_function!(max);
616aggregate_function!(min);
617aggregate_function!(avg);
618aggregate_function!(sum);
619aggregate_function!(bit_and);
620aggregate_function!(bit_or);
621aggregate_function!(bit_xor);
622aggregate_function!(bool_and);
623aggregate_function!(bool_or);
624aggregate_function!(corr, y x);
625aggregate_function!(count);
626aggregate_function!(covar_samp, y x);
627aggregate_function!(covar_pop, y x);
628aggregate_function!(median);
629aggregate_function!(regr_slope, y x);
630aggregate_function!(regr_intercept, y x);
631aggregate_function!(regr_count, y x);
632aggregate_function!(regr_r2, y x);
633aggregate_function!(regr_avgx, y x);
634aggregate_function!(regr_avgy, y x);
635aggregate_function!(regr_sxx, y x);
636aggregate_function!(regr_syy, y x);
637aggregate_function!(regr_sxy, y x);
638aggregate_function!(stddev);
639aggregate_function!(stddev_pop);
640aggregate_function!(var_sample);
641aggregate_function!(var_pop);
642aggregate_function!(approx_distinct);
643aggregate_function!(approx_median);
644
645#[pyfunction]
650#[pyo3(signature = (expression, percentile, num_centroids=None, filter=None))]
651pub fn approx_percentile_cont(
652 expression: PyExpr,
653 percentile: f64,
654 num_centroids: Option<i64>, filter: Option<PyExpr>,
656) -> PyDataFusionResult<PyExpr> {
657 let args = if let Some(num_centroids) = num_centroids {
658 vec![expression.expr, lit(percentile), lit(num_centroids)]
659 } else {
660 vec![expression.expr, lit(percentile)]
661 };
662 let udaf = functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf();
663 let agg_fn = udaf.call(args);
664
665 add_builder_fns_to_aggregate(agg_fn, None, filter, None, None)
666}
667
668#[pyfunction]
669#[pyo3(signature = (expression, weight, percentile, filter=None))]
670pub fn approx_percentile_cont_with_weight(
671 expression: PyExpr,
672 weight: PyExpr,
673 percentile: f64,
674 filter: Option<PyExpr>,
675) -> PyDataFusionResult<PyExpr> {
676 let agg_fn = functions_aggregate::expr_fn::approx_percentile_cont_with_weight(
677 expression.expr,
678 weight.expr,
679 lit(percentile),
680 );
681
682 add_builder_fns_to_aggregate(agg_fn, None, filter, None, None)
683}
684
685#[pyfunction]
688#[pyo3(signature = (expr, distinct=None, filter=None, order_by=None, null_treatment=None))]
689pub fn last_value(
690 expr: PyExpr,
691 distinct: Option<bool>,
692 filter: Option<PyExpr>,
693 order_by: Option<Vec<PySortExpr>>,
694 null_treatment: Option<NullTreatment>,
695) -> PyDataFusionResult<PyExpr> {
696 let agg_fn = functions_aggregate::expr_fn::last_value(expr.expr, None);
698
699 add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
700}
701#[pyfunction]
704#[pyo3(signature = (expr, distinct=None, filter=None, order_by=None, null_treatment=None))]
705pub fn first_value(
706 expr: PyExpr,
707 distinct: Option<bool>,
708 filter: Option<PyExpr>,
709 order_by: Option<Vec<PySortExpr>>,
710 null_treatment: Option<NullTreatment>,
711) -> PyDataFusionResult<PyExpr> {
712 let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None);
714
715 add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
716}
717
718#[pyfunction]
720#[pyo3(signature = (expr, n, distinct=None, filter=None, order_by=None, null_treatment=None))]
721pub fn nth_value(
722 expr: PyExpr,
723 n: i64,
724 distinct: Option<bool>,
725 filter: Option<PyExpr>,
726 order_by: Option<Vec<PySortExpr>>,
727 null_treatment: Option<NullTreatment>,
728) -> PyDataFusionResult<PyExpr> {
729 let agg_fn = datafusion::functions_aggregate::nth_value::nth_value(expr.expr, n, vec![]);
730 add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
731}
732
733#[pyfunction]
735#[pyo3(signature = (expr, delimiter, distinct=None, filter=None, order_by=None, null_treatment=None))]
736pub fn string_agg(
737 expr: PyExpr,
738 delimiter: String,
739 distinct: Option<bool>,
740 filter: Option<PyExpr>,
741 order_by: Option<Vec<PySortExpr>>,
742 null_treatment: Option<NullTreatment>,
743) -> PyDataFusionResult<PyExpr> {
744 let agg_fn = datafusion::functions_aggregate::string_agg::string_agg(expr.expr, lit(delimiter));
745 add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
746}
747
748pub(crate) fn add_builder_fns_to_window(
749 window_fn: Expr,
750 partition_by: Option<Vec<PyExpr>>,
751 window_frame: Option<PyWindowFrame>,
752 order_by: Option<Vec<PySortExpr>>,
753 null_treatment: Option<NullTreatment>,
754) -> PyDataFusionResult<PyExpr> {
755 let null_treatment = null_treatment.map(|n| n.into());
756 let mut builder = window_fn.null_treatment(null_treatment);
757
758 if let Some(partition_cols) = partition_by {
759 builder = builder.partition_by(
760 partition_cols
761 .into_iter()
762 .map(|col| col.clone().into())
763 .collect(),
764 );
765 }
766
767 if let Some(order_by_cols) = order_by {
768 let order_by_cols = to_sort_expressions(order_by_cols);
769 builder = builder.order_by(order_by_cols);
770 }
771
772 if let Some(window_frame) = window_frame {
773 builder = builder.window_frame(window_frame.into());
774 }
775
776 Ok(builder.build().map(|e| e.into())?)
777}
778
779#[pyfunction]
780#[pyo3(signature = (arg, shift_offset, default_value=None, partition_by=None, order_by=None))]
781pub fn lead(
782 arg: PyExpr,
783 shift_offset: i64,
784 default_value: Option<PyScalarValue>,
785 partition_by: Option<Vec<PyExpr>>,
786 order_by: Option<Vec<PySortExpr>>,
787) -> PyDataFusionResult<PyExpr> {
788 let default_value = default_value.map(|v| v.into());
789 let window_fn = functions_window::expr_fn::lead(arg.expr, Some(shift_offset), default_value);
790
791 add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
792}
793
794#[pyfunction]
795#[pyo3(signature = (arg, shift_offset, default_value=None, partition_by=None, order_by=None))]
796pub fn lag(
797 arg: PyExpr,
798 shift_offset: i64,
799 default_value: Option<PyScalarValue>,
800 partition_by: Option<Vec<PyExpr>>,
801 order_by: Option<Vec<PySortExpr>>,
802) -> PyDataFusionResult<PyExpr> {
803 let default_value = default_value.map(|v| v.into());
804 let window_fn = functions_window::expr_fn::lag(arg.expr, Some(shift_offset), default_value);
805
806 add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
807}
808
809#[pyfunction]
810#[pyo3(signature = (partition_by=None, order_by=None))]
811pub fn row_number(
812 partition_by: Option<Vec<PyExpr>>,
813 order_by: Option<Vec<PySortExpr>>,
814) -> PyDataFusionResult<PyExpr> {
815 let window_fn = functions_window::expr_fn::row_number();
816
817 add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
818}
819
820#[pyfunction]
821#[pyo3(signature = (partition_by=None, order_by=None))]
822pub fn rank(
823 partition_by: Option<Vec<PyExpr>>,
824 order_by: Option<Vec<PySortExpr>>,
825) -> PyDataFusionResult<PyExpr> {
826 let window_fn = functions_window::expr_fn::rank();
827
828 add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
829}
830
831#[pyfunction]
832#[pyo3(signature = (partition_by=None, order_by=None))]
833pub fn dense_rank(
834 partition_by: Option<Vec<PyExpr>>,
835 order_by: Option<Vec<PySortExpr>>,
836) -> PyDataFusionResult<PyExpr> {
837 let window_fn = functions_window::expr_fn::dense_rank();
838
839 add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
840}
841
842#[pyfunction]
843#[pyo3(signature = (partition_by=None, order_by=None))]
844pub fn percent_rank(
845 partition_by: Option<Vec<PyExpr>>,
846 order_by: Option<Vec<PySortExpr>>,
847) -> PyDataFusionResult<PyExpr> {
848 let window_fn = functions_window::expr_fn::percent_rank();
849
850 add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
851}
852
853#[pyfunction]
854#[pyo3(signature = (partition_by=None, order_by=None))]
855pub fn cume_dist(
856 partition_by: Option<Vec<PyExpr>>,
857 order_by: Option<Vec<PySortExpr>>,
858) -> PyDataFusionResult<PyExpr> {
859 let window_fn = functions_window::expr_fn::cume_dist();
860
861 add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
862}
863
864#[pyfunction]
865#[pyo3(signature = (arg, partition_by=None, order_by=None))]
866pub fn ntile(
867 arg: PyExpr,
868 partition_by: Option<Vec<PyExpr>>,
869 order_by: Option<Vec<PySortExpr>>,
870) -> PyDataFusionResult<PyExpr> {
871 let window_fn = functions_window::expr_fn::ntile(arg.into());
872
873 add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
874}
875
876pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
877 m.add_wrapped(wrap_pyfunction!(abs))?;
878 m.add_wrapped(wrap_pyfunction!(acos))?;
879 m.add_wrapped(wrap_pyfunction!(acosh))?;
880 m.add_wrapped(wrap_pyfunction!(approx_distinct))?;
881 m.add_wrapped(wrap_pyfunction!(alias))?;
882 m.add_wrapped(wrap_pyfunction!(approx_median))?;
883 m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?;
884 m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?;
885 m.add_wrapped(wrap_pyfunction!(range))?;
886 m.add_wrapped(wrap_pyfunction!(array_agg))?;
887 m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
888 m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
889 m.add_wrapped(wrap_pyfunction!(ascii))?;
890 m.add_wrapped(wrap_pyfunction!(asin))?;
891 m.add_wrapped(wrap_pyfunction!(asinh))?;
892 m.add_wrapped(wrap_pyfunction!(atan))?;
893 m.add_wrapped(wrap_pyfunction!(atanh))?;
894 m.add_wrapped(wrap_pyfunction!(atan2))?;
895 m.add_wrapped(wrap_pyfunction!(avg))?;
896 m.add_wrapped(wrap_pyfunction!(bit_length))?;
897 m.add_wrapped(wrap_pyfunction!(btrim))?;
898 m.add_wrapped(wrap_pyfunction!(cbrt))?;
899 m.add_wrapped(wrap_pyfunction!(ceil))?;
900 m.add_wrapped(wrap_pyfunction!(character_length))?;
901 m.add_wrapped(wrap_pyfunction!(chr))?;
902 m.add_wrapped(wrap_pyfunction!(char_length))?;
903 m.add_wrapped(wrap_pyfunction!(coalesce))?;
904 m.add_wrapped(wrap_pyfunction!(case))?;
905 m.add_wrapped(wrap_pyfunction!(when))?;
906 m.add_wrapped(wrap_pyfunction!(col))?;
907 m.add_wrapped(wrap_pyfunction!(concat_ws))?;
908 m.add_wrapped(wrap_pyfunction!(concat))?;
909 m.add_wrapped(wrap_pyfunction!(corr))?;
910 m.add_wrapped(wrap_pyfunction!(cos))?;
911 m.add_wrapped(wrap_pyfunction!(cosh))?;
912 m.add_wrapped(wrap_pyfunction!(cot))?;
913 m.add_wrapped(wrap_pyfunction!(count))?;
914 m.add_wrapped(wrap_pyfunction!(covar_pop))?;
915 m.add_wrapped(wrap_pyfunction!(covar_samp))?;
916 m.add_wrapped(wrap_pyfunction!(current_date))?;
917 m.add_wrapped(wrap_pyfunction!(current_time))?;
918 m.add_wrapped(wrap_pyfunction!(degrees))?;
919 m.add_wrapped(wrap_pyfunction!(date_bin))?;
920 m.add_wrapped(wrap_pyfunction!(date_part))?;
921 m.add_wrapped(wrap_pyfunction!(date_trunc))?;
922 m.add_wrapped(wrap_pyfunction!(make_date))?;
923 m.add_wrapped(wrap_pyfunction!(digest))?;
924 m.add_wrapped(wrap_pyfunction!(ends_with))?;
925 m.add_wrapped(wrap_pyfunction!(exp))?;
926 m.add_wrapped(wrap_pyfunction!(factorial))?;
927 m.add_wrapped(wrap_pyfunction!(floor))?;
928 m.add_wrapped(wrap_pyfunction!(from_unixtime))?;
929 m.add_wrapped(wrap_pyfunction!(gcd))?;
930 m.add_wrapped(wrap_pyfunction!(in_list))?;
932 m.add_wrapped(wrap_pyfunction!(initcap))?;
933 m.add_wrapped(wrap_pyfunction!(isnan))?;
934 m.add_wrapped(wrap_pyfunction!(iszero))?;
935 m.add_wrapped(wrap_pyfunction!(levenshtein))?;
936 m.add_wrapped(wrap_pyfunction!(lcm))?;
937 m.add_wrapped(wrap_pyfunction!(left))?;
938 m.add_wrapped(wrap_pyfunction!(length))?;
939 m.add_wrapped(wrap_pyfunction!(ln))?;
940 m.add_wrapped(wrap_pyfunction!(log))?;
941 m.add_wrapped(wrap_pyfunction!(log10))?;
942 m.add_wrapped(wrap_pyfunction!(log2))?;
943 m.add_wrapped(wrap_pyfunction!(lower))?;
944 m.add_wrapped(wrap_pyfunction!(lpad))?;
945 m.add_wrapped(wrap_pyfunction!(ltrim))?;
946 m.add_wrapped(wrap_pyfunction!(max))?;
947 m.add_wrapped(wrap_pyfunction!(make_array))?;
948 m.add_wrapped(wrap_pyfunction!(md5))?;
949 m.add_wrapped(wrap_pyfunction!(median))?;
950 m.add_wrapped(wrap_pyfunction!(min))?;
951 m.add_wrapped(wrap_pyfunction!(named_struct))?;
952 m.add_wrapped(wrap_pyfunction!(nanvl))?;
953 m.add_wrapped(wrap_pyfunction!(nvl))?;
954 m.add_wrapped(wrap_pyfunction!(now))?;
955 m.add_wrapped(wrap_pyfunction!(nullif))?;
956 m.add_wrapped(wrap_pyfunction!(octet_length))?;
957 m.add_wrapped(wrap_pyfunction!(order_by))?;
958 m.add_wrapped(wrap_pyfunction!(overlay))?;
959 m.add_wrapped(wrap_pyfunction!(pi))?;
960 m.add_wrapped(wrap_pyfunction!(power))?;
961 m.add_wrapped(wrap_pyfunction!(radians))?;
962 m.add_wrapped(wrap_pyfunction!(random))?;
963 m.add_wrapped(wrap_pyfunction!(regexp_count))?;
964 m.add_wrapped(wrap_pyfunction!(regexp_like))?;
965 m.add_wrapped(wrap_pyfunction!(regexp_match))?;
966 m.add_wrapped(wrap_pyfunction!(regexp_replace))?;
967 m.add_wrapped(wrap_pyfunction!(repeat))?;
968 m.add_wrapped(wrap_pyfunction!(replace))?;
969 m.add_wrapped(wrap_pyfunction!(reverse))?;
970 m.add_wrapped(wrap_pyfunction!(right))?;
971 m.add_wrapped(wrap_pyfunction!(round))?;
972 m.add_wrapped(wrap_pyfunction!(rpad))?;
973 m.add_wrapped(wrap_pyfunction!(rtrim))?;
974 m.add_wrapped(wrap_pyfunction!(sha224))?;
975 m.add_wrapped(wrap_pyfunction!(sha256))?;
976 m.add_wrapped(wrap_pyfunction!(sha384))?;
977 m.add_wrapped(wrap_pyfunction!(sha512))?;
978 m.add_wrapped(wrap_pyfunction!(signum))?;
979 m.add_wrapped(wrap_pyfunction!(sin))?;
980 m.add_wrapped(wrap_pyfunction!(sinh))?;
981 m.add_wrapped(wrap_pyfunction!(split_part))?;
982 m.add_wrapped(wrap_pyfunction!(sqrt))?;
983 m.add_wrapped(wrap_pyfunction!(starts_with))?;
984 m.add_wrapped(wrap_pyfunction!(stddev))?;
985 m.add_wrapped(wrap_pyfunction!(stddev_pop))?;
986 m.add_wrapped(wrap_pyfunction!(string_agg))?;
987 m.add_wrapped(wrap_pyfunction!(strpos))?;
988 m.add_wrapped(wrap_pyfunction!(r#struct))?; m.add_wrapped(wrap_pyfunction!(substr))?;
990 m.add_wrapped(wrap_pyfunction!(substr_index))?;
991 m.add_wrapped(wrap_pyfunction!(substring))?;
992 m.add_wrapped(wrap_pyfunction!(find_in_set))?;
993 m.add_wrapped(wrap_pyfunction!(sum))?;
994 m.add_wrapped(wrap_pyfunction!(tan))?;
995 m.add_wrapped(wrap_pyfunction!(tanh))?;
996 m.add_wrapped(wrap_pyfunction!(to_hex))?;
997 m.add_wrapped(wrap_pyfunction!(to_timestamp))?;
998 m.add_wrapped(wrap_pyfunction!(to_timestamp_millis))?;
999 m.add_wrapped(wrap_pyfunction!(to_timestamp_nanos))?;
1000 m.add_wrapped(wrap_pyfunction!(to_timestamp_micros))?;
1001 m.add_wrapped(wrap_pyfunction!(to_timestamp_seconds))?;
1002 m.add_wrapped(wrap_pyfunction!(to_unixtime))?;
1003 m.add_wrapped(wrap_pyfunction!(translate))?;
1004 m.add_wrapped(wrap_pyfunction!(trim))?;
1005 m.add_wrapped(wrap_pyfunction!(trunc))?;
1006 m.add_wrapped(wrap_pyfunction!(upper))?;
1007 m.add_wrapped(wrap_pyfunction!(self::uuid))?; m.add_wrapped(wrap_pyfunction!(var_pop))?;
1009 m.add_wrapped(wrap_pyfunction!(var_sample))?;
1010 m.add_wrapped(wrap_pyfunction!(window))?;
1011 m.add_wrapped(wrap_pyfunction!(regr_avgx))?;
1012 m.add_wrapped(wrap_pyfunction!(regr_avgy))?;
1013 m.add_wrapped(wrap_pyfunction!(regr_count))?;
1014 m.add_wrapped(wrap_pyfunction!(regr_intercept))?;
1015 m.add_wrapped(wrap_pyfunction!(regr_r2))?;
1016 m.add_wrapped(wrap_pyfunction!(regr_slope))?;
1017 m.add_wrapped(wrap_pyfunction!(regr_sxx))?;
1018 m.add_wrapped(wrap_pyfunction!(regr_sxy))?;
1019 m.add_wrapped(wrap_pyfunction!(regr_syy))?;
1020 m.add_wrapped(wrap_pyfunction!(first_value))?;
1021 m.add_wrapped(wrap_pyfunction!(last_value))?;
1022 m.add_wrapped(wrap_pyfunction!(nth_value))?;
1023 m.add_wrapped(wrap_pyfunction!(bit_and))?;
1024 m.add_wrapped(wrap_pyfunction!(bit_or))?;
1025 m.add_wrapped(wrap_pyfunction!(bit_xor))?;
1026 m.add_wrapped(wrap_pyfunction!(bool_and))?;
1027 m.add_wrapped(wrap_pyfunction!(bool_or))?;
1028
1029 m.add_wrapped(wrap_pyfunction!(encode))?;
1031 m.add_wrapped(wrap_pyfunction!(decode))?;
1032
1033 m.add_wrapped(wrap_pyfunction!(array_append))?;
1035 m.add_wrapped(wrap_pyfunction!(array_concat))?;
1036 m.add_wrapped(wrap_pyfunction!(array_cat))?;
1037 m.add_wrapped(wrap_pyfunction!(array_dims))?;
1038 m.add_wrapped(wrap_pyfunction!(array_distinct))?;
1039 m.add_wrapped(wrap_pyfunction!(array_element))?;
1040 m.add_wrapped(wrap_pyfunction!(array_empty))?;
1041 m.add_wrapped(wrap_pyfunction!(array_length))?;
1042 m.add_wrapped(wrap_pyfunction!(array_has))?;
1043 m.add_wrapped(wrap_pyfunction!(array_has_all))?;
1044 m.add_wrapped(wrap_pyfunction!(array_has_any))?;
1045 m.add_wrapped(wrap_pyfunction!(array_position))?;
1046 m.add_wrapped(wrap_pyfunction!(array_positions))?;
1047 m.add_wrapped(wrap_pyfunction!(array_to_string))?;
1048 m.add_wrapped(wrap_pyfunction!(array_intersect))?;
1049 m.add_wrapped(wrap_pyfunction!(array_union))?;
1050 m.add_wrapped(wrap_pyfunction!(array_except))?;
1051 m.add_wrapped(wrap_pyfunction!(array_resize))?;
1052 m.add_wrapped(wrap_pyfunction!(array_ndims))?;
1053 m.add_wrapped(wrap_pyfunction!(array_prepend))?;
1054 m.add_wrapped(wrap_pyfunction!(array_pop_back))?;
1055 m.add_wrapped(wrap_pyfunction!(array_pop_front))?;
1056 m.add_wrapped(wrap_pyfunction!(array_remove))?;
1057 m.add_wrapped(wrap_pyfunction!(array_remove_n))?;
1058 m.add_wrapped(wrap_pyfunction!(array_remove_all))?;
1059 m.add_wrapped(wrap_pyfunction!(array_repeat))?;
1060 m.add_wrapped(wrap_pyfunction!(array_replace))?;
1061 m.add_wrapped(wrap_pyfunction!(array_replace_n))?;
1062 m.add_wrapped(wrap_pyfunction!(array_replace_all))?;
1063 m.add_wrapped(wrap_pyfunction!(array_sort))?;
1064 m.add_wrapped(wrap_pyfunction!(array_slice))?;
1065 m.add_wrapped(wrap_pyfunction!(flatten))?;
1066 m.add_wrapped(wrap_pyfunction!(cardinality))?;
1067
1068 m.add_wrapped(wrap_pyfunction!(lead))?;
1070 m.add_wrapped(wrap_pyfunction!(lag))?;
1071 m.add_wrapped(wrap_pyfunction!(rank))?;
1072 m.add_wrapped(wrap_pyfunction!(row_number))?;
1073 m.add_wrapped(wrap_pyfunction!(dense_rank))?;
1074 m.add_wrapped(wrap_pyfunction!(percent_rank))?;
1075 m.add_wrapped(wrap_pyfunction!(cume_dist))?;
1076 m.add_wrapped(wrap_pyfunction!(ntile))?;
1077
1078 Ok(())
1079}