nu_command/math/variance.rs

1use crate::math::utils::run_with_function;
2use nu_engine::command_prelude::*;
3
/// The `math variance` command; its behavior is defined by the `Command` impl below.
#[derive(Clone)]
pub struct MathVariance;
6
7impl Command for MathVariance {
8    fn name(&self) -> &str {
9        "math variance"
10    }
11
12    fn signature(&self) -> Signature {
13        Signature::build("math variance")
14            .input_output_types(vec![
15                (Type::List(Box::new(Type::Number)), Type::Number),
16                (Type::Range, Type::Number),
17                (Type::table(), Type::record()),
18                (Type::record(), Type::record()),
19            ])
20            .switch(
21                "sample",
22                "calculate sample variance (i.e. using N-1 as the denominator)",
23                Some('s'),
24            )
25            .allow_variants_without_examples(true)
26            .category(Category::Math)
27    }
28
29    fn description(&self) -> &str {
30        "Returns the variance of a list of numbers or of each column in a table."
31    }
32
33    fn search_terms(&self) -> Vec<&str> {
34        vec!["deviation", "dispersion", "variation", "statistics"]
35    }
36
37    fn is_const(&self) -> bool {
38        true
39    }
40
41    fn run(
42        &self,
43        engine_state: &EngineState,
44        stack: &mut Stack,
45        call: &Call,
46        input: PipelineData,
47    ) -> Result<PipelineData, ShellError> {
48        let sample = call.has_flag(engine_state, stack, "sample")?;
49        let name = call.head;
50        let span = input.span().unwrap_or(name);
51        let input: PipelineData = match input.try_expand_range() {
52            Err(_) => {
53                return Err(ShellError::IncorrectValue {
54                    msg: "Range must be bounded".to_string(),
55                    val_span: span,
56                    call_span: name,
57                });
58            }
59            Ok(val) => val,
60        };
61        run_with_function(call, input, compute_variance(sample))
62    }
63
64    fn run_const(
65        &self,
66        working_set: &StateWorkingSet,
67        call: &Call,
68        input: PipelineData,
69    ) -> Result<PipelineData, ShellError> {
70        let sample = call.has_flag_const(working_set, "sample")?;
71        let name = call.head;
72        let span = input.span().unwrap_or(name);
73        let input: PipelineData = match input.try_expand_range() {
74            Err(_) => {
75                return Err(ShellError::IncorrectValue {
76                    msg: "Range must be bounded".to_string(),
77                    val_span: span,
78                    call_span: name,
79                });
80            }
81            Ok(val) => val,
82        };
83        run_with_function(call, input, compute_variance(sample))
84    }
85
86    fn examples(&self) -> Vec<Example> {
87        vec![
88            Example {
89                description: "Get the variance of a list of numbers",
90                example: "[1 2 3 4 5] | math variance",
91                result: Some(Value::test_float(2.0)),
92            },
93            Example {
94                description: "Get the sample variance of a list of numbers",
95                example: "[1 2 3 4 5] | math variance --sample",
96                result: Some(Value::test_float(2.5)),
97            },
98            Example {
99                description: "Compute the variance of each column in a table",
100                example: "[[a b]; [1 2] [3 4]] | math variance",
101                result: Some(Value::test_record(record! {
102                    "a" => Value::test_int(1),
103                    "b" => Value::test_int(1),
104                })),
105            },
106        ]
107    }
108}
109
110fn sum_of_squares(values: &[Value], span: Span) -> Result<Value, ShellError> {
111    let n = Value::int(values.len() as i64, span);
112    let mut sum_x = Value::int(0, span);
113    let mut sum_x2 = Value::int(0, span);
114    for value in values {
115        let v = match &value {
116            Value::Int { .. } | Value::Float { .. } => Ok(value),
117            Value::Error { error, .. } => Err(*error.clone()),
118            other => Err(ShellError::UnsupportedInput {
119                msg: format!("Attempted to compute the sum of squares of a non-int, non-float value '{}' with a type of `{}`.",
120                        other.coerce_string()?, other.get_type()),
121                input: "value originates from here".into(),
122                msg_span: span,
123                input_span: value.span(),
124            }),
125        }?;
126        let v_squared = &v.mul(span, v, span)?;
127        sum_x2 = sum_x2.add(span, v_squared, span)?;
128        sum_x = sum_x.add(span, v, span)?;
129    }
130
131    let sum_x_squared = sum_x.mul(span, &sum_x, span)?;
132    let sum_x_squared_div_n = sum_x_squared.div(span, &n, span)?;
133
134    let ss = sum_x2.sub(span, &sum_x_squared_div_n, span)?;
135
136    Ok(ss)
137}
138
139pub fn compute_variance(
140    sample: bool,
141) -> impl Fn(&[Value], Span, Span) -> Result<Value, ShellError> {
142    move |values: &[Value], span: Span, head: Span| {
143        let n = if sample {
144            values.len() - 1
145        } else {
146            values.len()
147        };
148        // sum_of_squares() needs the span of the original value, not the call head.
149        let ss = sum_of_squares(values, span)?;
150        let n = Value::int(n as i64, head);
151        ss.div(head, &n, head)
152    }
153}
154
#[cfg(test)]
mod test {
    use super::*;

    /// Delegates to the crate's shared example harness, which presumably checks each entry
    /// returned by `MathVariance::examples`.
    #[test]
    fn test_examples() {
        // Fully qualified: a module-level `use` would clash with this fn's name.
        crate::test_examples(MathVariance)
    }
}
165}