polars_plan/dsl/functions/
correlation.rs

1use super::*;
2
3/// Compute the covariance between two columns.
4pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr {
5    let function = FunctionExpr::Correlation {
6        method: CorrelationMethod::Covariance(ddof),
7    };
8    a.map_binary(function, b)
9}
10
11/// Compute the pearson correlation between two columns.
12pub fn pearson_corr(a: Expr, b: Expr) -> Expr {
13    let function = FunctionExpr::Correlation {
14        method: CorrelationMethod::Pearson,
15    };
16    a.map_binary(function, b)
17}
18
/// Build an expression that computes the Spearman rank correlation between two columns.
///
/// Missing data is excluded from the computation.
///
/// # Arguments
/// * `propagate_nans` - If `true`, any `NaN` encountered leads to `NaN` in the output.
///   If `false`, `NaN` values are regarded as larger than any finite number
///   and therefore receive the highest rank.
#[cfg(all(feature = "rank", feature = "propagate_nans"))]
pub fn spearman_rank_corr(a: Expr, b: Expr, propagate_nans: bool) -> Expr {
    let method = CorrelationMethod::SpearmanRank(propagate_nans);
    a.map_binary(FunctionExpr::Correlation { method }, b)
}
33
/// Shared builder for the rolling correlation / covariance expressions.
///
/// Wraps `x` and `y` as the two inputs of a `RollingFunction::CorrCov`
/// function expression. `window_size` and `min_periods` are read from
/// `options`, cast to `usize`, and placed in a `RollingOptionsFixedWindow`
/// whose remaining fields take their defaults; `options` itself is also
/// forwarded as `corr_cov_options`. `is_corr` is passed through unchanged.
#[cfg(all(feature = "rolling_window", feature = "cov"))]
fn dispatch_corr_cov(x: Expr, y: Expr, options: RollingCovOptions, is_corr: bool) -> Expr {
    // see: https://github.com/pandas-dev/pandas/blob/v1.5.1/pandas/core/window/rolling.py#L1780-L1804
    let window_size = options.window_size as usize;
    let min_periods = options.min_periods as usize;

    let function = FunctionExpr::RollingExpr(RollingFunction::CorrCov {
        rolling_options: RollingOptionsFixedWindow {
            window_size,
            min_periods,
            ..Default::default()
        },
        corr_cov_options: options,
        is_corr,
    });

    Expr::Function {
        input: vec![x, y],
        function,
    }
}
52
/// Build a rolling correlation expression over `x` and `y`.
///
/// Delegates to `dispatch_corr_cov` with the correlation flag set to `true`;
/// `options` is forwarded unchanged.
#[cfg(all(feature = "rolling_window", feature = "cov"))]
pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {
    let is_corr = true;
    dispatch_corr_cov(x, y, options, is_corr)
}
57
/// Build a rolling covariance expression over `x` and `y`.
///
/// Delegates to `dispatch_corr_cov` with the correlation flag set to `false`;
/// `options` is forwarded unchanged.
#[cfg(all(feature = "rolling_window", feature = "cov"))]
pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {
    let is_corr = false;
    dispatch_corr_cov(x, y, options, is_corr)
}