nu_plugin_polars 0.112.0

Nushell dataframe plugin commands based on polars.
use nu_plugin::{EngineInterface, EvaluatedCall, PluginCommand};
use nu_protocol::shell_error::generic::GenericError;
use nu_protocol::{
    Category, Example, LabeledError, PipelineData, ShellError, Signature, Span, SyntaxShape, Value,
};

use polars::{
    df,
    frame::DataFrame,
    prelude::{Expr, PlSmallStr, Selector, element},
};

use crate::{
    PolarsPlugin,
    command::required_flag,
    values::{CustomValueSupport, NuExpression, NuLazyFrame, NuSelector, PolarsPluginType},
};

use crate::values::NuDataFrame;

/// Plugin command type backing `polars pivot`, which reshapes a dataframe
/// from long to wide format.
#[derive(Clone)]
pub struct PivotDF;

impl PluginCommand for PivotDF {
    type Plugin = PolarsPlugin;

    /// Name under which the command is registered.
    fn name(&self) -> &str {
        "polars pivot"
    }

    /// One-line summary shown in the command help.
    fn description(&self) -> &str {
        "Pivot a DataFrame from long to wide format."
    }

    /// Declares the flags accepted by `polars pivot`.
    ///
    /// `--on` and `--on-cols` are required; at least one of `--index` /
    /// `--values` must also be given (enforced in `command_lazy`).
    fn signature(&self) -> Signature {
        Signature::build(self.name())
            .required_named(
                "on",
                SyntaxShape::Any,
                "Column names for pivoting.",
                Some('o'),
            )
            .required_named(
                "on-cols",
                SyntaxShape::Any,
                "Values of the `on` column(s) to create output columns for.",
                Some('c'),
            )
            .named(
                "index",
                SyntaxShape::Any,
                "Selector or column names for indexes.",
                Some('i'),
            )
            .named(
                "values",
                SyntaxShape::Any,
                "Selector or column names used as value columns.",
                None,
            )
            .named(
                "aggregate",
                SyntaxShape::Any,
                "Aggregation to apply when pivoting. The following are supported: first, sum, min, max, mean, median, count, last, or a custom expression.",
                Some('a'),
            )
            .named(
                "separator",
                SyntaxShape::String,
                "Delimiter in generated column names in case of multiple `values` columns (default '_').",
                Some('p'),
            )
            .switch(
                "maintain-order",
                "Maintain Order.",
                None,
            )
            // NOTE(review): nothing in this file reads `streamable` or `stable`;
            // confirm whether they should be wired into the pivot call or retired.
            .switch(
                "streamable",
                "Whether or not to use the polars streaming engine. Only valid for lazy dataframes",
                Some('t'),
            )
            .switch(
                "stable",
                "Perform a stable pivot.",
                None,
            )
            .input_output_types(vec![
                (
                    PolarsPluginType::NuDataFrame.into(),
                    PolarsPluginType::NuDataFrame.into(),
                ),
                (
                    PolarsPluginType::NuLazyFrame.into(),
                    PolarsPluginType::NuLazyFrame.into(),
                ),
            ])
            .category(Category::Custom("dataframe".into()))
    }

    /// Usage examples; each carries the dataframe it is expected to produce.
    fn examples(&self) -> Vec<Example<'_>> {
        vec![
            Example {
                example: r#"{
        "name": ["Cady", "Cady", "Karen", "Karen"],
        "subject": ["maths", "physics", "maths", "physics"],
        "test_1": [98, 99, 61, 58],
        "test_2": [100, 100, 60, 60],
    } | 
    polars into-df --as-columns | 
    polars pivot --on subject --on-cols [maths physics] --index name --values test_1 |
    polars sort-by name maths physics |
    polars collect"#,
                description: "Given a set of test scores, reshape so we have one row per student, with different subjects as columns, and their `test_1` scores as values",
                result: Some(
                    NuDataFrame::from(
                        df!(
                            "name" => ["Cady", "Karen"],
                            "maths" => [98, 61],
                            "physics" => [99, 58],
                        )
                        .expect("Could not create test dataframe"),
                    )
                    .into_value(Span::test_data()),
                ),
            },
            Example {
                example: r#"{
        "name": ["Cady", "Cady", "Karen", "Karen"],
        "subject": ["maths", "physics", "maths", "physics"],
        "test_1": [98, 99, 61, 58],
        "test_2": [100, 100, 60, 60],
    } |
    polars into-df --as-columns |
    polars pivot --on subject --on-cols [maths physics] --index name --values (polars selector starts-with test) |
    polars sort-by name test_1_maths test_1_physics test_2_maths test_2_physics |
    polars collect"#,
                description: "Given a set of test scores, reshape so we have one row per student, using a selector for the values to include all test scores",
                result: Some(
                    NuDataFrame::from(
                        df!(
                            "name" => ["Cady", "Karen"],
                            "test_1_maths" => [98, 61],
                            "test_1_physics" => [99, 58],
                            "test_2_maths" => [100, 60],
                            "test_2_physics" => [100, 60],
                        )
                        .expect("Could not create test dataframe"),
                    )
                    .into_value(Span::test_data()),
                ),
            },
            Example {
                example: r#"{
        "ix": [1, 1, 2, 2, 1, 2],
        "col": ["a", "a", "a", "a", "b", "b"],
        "foo": [0, 1, 2, 2, 7, 1],
        "bar": [0, 2, 0, 0, 9, 4],
    } |
    polars into-df --as-columns |
    polars pivot --on col --on-cols [a b] --index ix --aggregate sum |
    polars sort-by ix foo_a foo_b bar_a bar_b |
    polars collect"#,
                description: "Given a DataFrame with duplicate entries for the pivot columns, use the `aggregate` flag to specify how to aggregate values for those duplicates. In this example, we sum the `foo` and `bar` values for rows with the same `ix` and `col` values.",
                result: Some(
                    NuDataFrame::from(
                        df!(
                            "ix" => [1, 2],
                            "foo_a" => [1, 4],
                            "foo_b" => [7, 1],
                            "bar_a" => [2, 0],
                            "bar_b" => [9, 4],
                        )
                        .expect("Could not create test dataframe"),
                    )
                    .into_value(Span::test_data()),
                ),
            },
        ]
    }

    /// Entry point invoked by the plugin host.
    ///
    /// Detaches the pipeline metadata, coerces the input into a `NuLazyFrame`,
    /// delegates to `command_lazy`, and re-attaches the metadata to the
    /// resulting pipeline data.
    fn run(
        &self,
        plugin: &Self::Plugin,
        engine: &EngineInterface,
        call: &EvaluatedCall,
        mut input: PipelineData,
    ) -> Result<PipelineData, LabeledError> {
        // Metadata must be taken before `input` is consumed by the coercion.
        let metadata = input.take_metadata();
        let lazy = NuLazyFrame::try_from_pipeline_coerce(plugin, input, call.head)?;
        command_lazy(plugin, engine, call, lazy)
            .map_err(LabeledError::from)
            .map(|pd| pd.set_metadata(metadata))
    }
}

fn command_lazy(
    plugin: &PolarsPlugin,
    engine: &EngineInterface,
    call: &EvaluatedCall,
    lazy: NuLazyFrame,
) -> Result<PipelineData, ShellError> {
    let on: Selector = call
        .get_flag::<Value>("on")?
        .map(|ref v| NuSelector::try_from_value(plugin, v))
        .transpose()?
        .ok_or(required_flag("on", call.head))?
        .into_polars();

    let on_columns: DataFrame = call
        .get_flag::<Value>("on-cols")?
        .map(|ref v| NuDataFrame::try_from_value(plugin, v))
        .transpose()?
        .ok_or(required_flag("on-cols", call.head))?
        .to_polars();

    let index: Option<Selector> = call
        .get_flag::<Value>("index")?
        .map(|ref v| NuSelector::try_from_value(plugin, v))
        .transpose()?
        .map(|s| s.into_polars());

    let values: Option<Selector> = call
        .get_flag::<Value>("values")?
        .map(|ref v| NuSelector::try_from_value(plugin, v))
        .transpose()?
        .map(|s| s.into_polars());

    let agg: Expr = call
        .get_flag::<Value>("aggregate")?
        .map(|val| pivot_agg_for_value(plugin, val))
        .transpose()?
        .unwrap_or(element().item(true));

    let maintain_order = call.has_flag("maintain-order")?;

    let separator: PlSmallStr = call
        .get_flag::<String>("separator")?
        .map(PlSmallStr::from)
        .unwrap_or_else(|| PlSmallStr::from("_"));

    if index.is_none() && values.is_none() {
        return Err(ShellError::Generic(GenericError::new(
            "`pivot` needs either `--index or `--values` needs to be specified",
            "",
            call.head,
        )));
    }

    let index_selector = if let Some(index) = index.clone() {
        index
    } else {
        Selector::Wildcard - on.clone() - values.clone().unwrap_or_else(|| Selector::Empty)
    };

    let values_selector = if let Some(values) = values {
        values
    } else {
        Selector::Wildcard - on.clone() - index.unwrap_or_else(|| Selector::Empty)
    };

    let result: NuLazyFrame = lazy
        .to_polars()
        .pivot(
            on,
            on_columns.into(),
            index_selector,
            values_selector,
            agg,
            maintain_order,
            separator,
        )
        .into();

    result.to_pipeline_data(plugin, engine, call.head)
}

/// Converts the `--aggregate` flag value into a polars aggregation expression.
///
/// A string selects one of the named aggregations below; a custom value is
/// converted through `NuExpression` and used as-is.
///
/// # Errors
///
/// Returns a `ShellError` pointing at the flag value when the string is not a
/// known aggregation name, when the custom value is not convertible to an
/// expression, or when the value is neither a string nor a custom value.
fn pivot_agg_for_value(plugin: &PolarsPlugin, agg: Value) -> Result<Expr, ShellError> {
    // Capture the span up front so every error can point at the offending value.
    let span = agg.span();

    match &agg {
        Value::String { val, .. } => match val.as_str() {
            "first" => Ok(element().first()),
            "sum" => Ok(element().sum()),
            "min" => Ok(element().min()),
            "max" => Ok(element().max()),
            "mean" => Ok(element().mean()),
            "median" => Ok(element().median()),
            "length" | "len" | "count" => Ok(element().len()),
            "last" => Ok(element().last()),
            "element" | "item" => Ok(element().item(true)),
            other => Err(ShellError::Generic(
                GenericError::new(format!("{other} is not a valid aggregation"), "", span)
                    .with_help(
                        "Use one of the following: first, sum, min, max, mean, median, count (or len, length), last, item (or element)",
                    ),
            )),
        },
        Value::Custom { .. } => {
            let expr = NuExpression::try_from_value(plugin, &agg)?;
            Ok(expr.into_polars())
        }
        _ => Err(ShellError::Generic(GenericError::new(
            "Aggregation must be a string or expression",
            "",
            span,
        ))),
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::test::test_polars_plugin_command;

    /// Hands `PivotDF` to the crate's shared plugin-command test helper
    /// (presumably it exercises the command's declared examples — confirm
    /// against `crate::test`).
    #[test]
    fn test_examples() -> Result<(), ShellError> {
        let command = PivotDF;
        test_polars_plugin_command(&command)
    }
}