nu_command/math/
variance.rs

1use crate::math::utils::run_with_function;
2use nu_engine::command_prelude::*;
3
4#[derive(Clone)]
5pub struct MathVariance;
6
7impl Command for MathVariance {
8    fn name(&self) -> &str {
9        "math variance"
10    }
11
12    fn signature(&self) -> Signature {
13        Signature::build("math variance")
14            .input_output_types(vec![
15                (Type::List(Box::new(Type::Number)), Type::Number),
16                (Type::Range, Type::Number),
17                (Type::table(), Type::record()),
18                (Type::record(), Type::record()),
19            ])
20            .switch(
21                "sample",
22                "calculate sample variance (i.e. using N-1 as the denominator)",
23                Some('s'),
24            )
25            .allow_variants_without_examples(true)
26            .category(Category::Math)
27    }
28
29    fn description(&self) -> &str {
30        "Returns the variance of a list of numbers or of each column in a table."
31    }
32
33    fn search_terms(&self) -> Vec<&str> {
34        vec!["deviation", "dispersion", "variation", "statistics"]
35    }
36
37    fn is_const(&self) -> bool {
38        true
39    }
40
41    fn run(
42        &self,
43        engine_state: &EngineState,
44        stack: &mut Stack,
45        call: &Call,
46        input: PipelineData,
47    ) -> Result<PipelineData, ShellError> {
48        let sample = call.has_flag(engine_state, stack, "sample")?;
49        let name = call.head;
50        let span = input.span().unwrap_or(name);
51        let input: PipelineData = match input.try_expand_range() {
52            Err(_) => {
53                return Err(ShellError::IncorrectValue {
54                    msg: "Range must be bounded".to_string(),
55                    val_span: span,
56                    call_span: name,
57                });
58            }
59            Ok(val) => val,
60        };
61        run_with_function(call, input, compute_variance(sample))
62    }
63
64    fn run_const(
65        &self,
66        working_set: &StateWorkingSet,
67        call: &Call,
68        input: PipelineData,
69    ) -> Result<PipelineData, ShellError> {
70        let sample = call.has_flag_const(working_set, "sample")?;
71        let name = call.head;
72        let span = input.span().unwrap_or(name);
73        let input: PipelineData = match input.try_expand_range() {
74            Err(_) => {
75                return Err(ShellError::IncorrectValue {
76                    msg: "Range must be bounded".to_string(),
77                    val_span: span,
78                    call_span: name,
79                });
80            }
81            Ok(val) => val,
82        };
83        run_with_function(call, input, compute_variance(sample))
84    }
85
86    fn examples(&self) -> Vec<Example> {
87        vec![
88            Example {
89                description: "Get the variance of a list of numbers",
90                example: "[1 2 3 4 5] | math variance",
91                result: Some(Value::test_float(2.0)),
92            },
93            Example {
94                description: "Get the sample variance of a list of numbers",
95                example: "[1 2 3 4 5] | math variance --sample",
96                result: Some(Value::test_float(2.5)),
97            },
98            Example {
99                description: "Compute the variance of each column in a table",
100                example: "[[a b]; [1 2] [3 4]] | math variance",
101                result: Some(Value::test_record(record! {
102                    "a" => Value::test_int(1),
103                    "b" => Value::test_int(1),
104                })),
105            },
106        ]
107    }
108}
109
110fn sum_of_squares(values: &[Value], span: Span) -> Result<Value, ShellError> {
111    let n = Value::int(values.len() as i64, span);
112    let mut sum_x = Value::int(0, span);
113    let mut sum_x2 = Value::int(0, span);
114    for value in values {
115        let v = match &value {
116            Value::Int { .. } | Value::Float { .. } => Ok(value),
117            Value::Error { error, .. } => Err(*error.clone()),
118            other => Err(ShellError::UnsupportedInput {
119                msg: format!(
120                    "Attempted to compute the sum of squares of a non-int, non-float value '{}' with a type of `{}`.",
121                    other.coerce_string()?,
122                    other.get_type()
123                ),
124                input: "value originates from here".into(),
125                msg_span: span,
126                input_span: value.span(),
127            }),
128        }?;
129        let v_squared = &v.mul(span, v, span)?;
130        sum_x2 = sum_x2.add(span, v_squared, span)?;
131        sum_x = sum_x.add(span, v, span)?;
132    }
133
134    let sum_x_squared = sum_x.mul(span, &sum_x, span)?;
135    let sum_x_squared_div_n = sum_x_squared.div(span, &n, span)?;
136
137    let ss = sum_x2.sub(span, &sum_x_squared_div_n, span)?;
138
139    Ok(ss)
140}
141
142pub fn compute_variance(
143    sample: bool,
144) -> impl Fn(&[Value], Span, Span) -> Result<Value, ShellError> {
145    move |values: &[Value], span: Span, head: Span| {
146        let n = if sample {
147            values.len() - 1
148        } else {
149            values.len()
150        };
151        // sum_of_squares() needs the span of the original value, not the call head.
152        let ss = sum_of_squares(values, span)?;
153        let n = Value::int(n as i64, head);
154        ss.div(head, &n, head)
155    }
156}
157
#[cfg(test)]
mod test {
    use super::MathVariance;

    /// Runs the command's declared examples through the crate's example checker.
    #[test]
    fn test_examples() {
        // Path-qualified so the helper does not clash with this test's own name.
        crate::test_examples(MathVariance)
    }
}