1use crate::math::utils::run_with_function;
2use nu_engine::command_prelude::*;
3
/// The `math variance` command: the variance (population by default, sample
/// with `--sample`) of a list of numbers or of each column in a table.
#[derive(Clone)]
pub struct MathVariance;
6
7impl Command for MathVariance {
8 fn name(&self) -> &str {
9 "math variance"
10 }
11
12 fn signature(&self) -> Signature {
13 Signature::build("math variance")
14 .input_output_types(vec![
15 (Type::List(Box::new(Type::Number)), Type::Number),
16 (Type::Range, Type::Number),
17 (Type::table(), Type::record()),
18 (Type::record(), Type::record()),
19 ])
20 .switch(
21 "sample",
22 "calculate sample variance (i.e. using N-1 as the denominator)",
23 Some('s'),
24 )
25 .allow_variants_without_examples(true)
26 .category(Category::Math)
27 }
28
29 fn description(&self) -> &str {
30 "Returns the variance of a list of numbers or of each column in a table."
31 }
32
33 fn search_terms(&self) -> Vec<&str> {
34 vec!["deviation", "dispersion", "variation", "statistics"]
35 }
36
37 fn is_const(&self) -> bool {
38 true
39 }
40
41 fn run(
42 &self,
43 engine_state: &EngineState,
44 stack: &mut Stack,
45 call: &Call,
46 input: PipelineData,
47 ) -> Result<PipelineData, ShellError> {
48 let sample = call.has_flag(engine_state, stack, "sample")?;
49 let name = call.head;
50 let span = input.span().unwrap_or(name);
51 let input: PipelineData = match input.try_expand_range() {
52 Err(_) => {
53 return Err(ShellError::IncorrectValue {
54 msg: "Range must be bounded".to_string(),
55 val_span: span,
56 call_span: name,
57 });
58 }
59 Ok(val) => val,
60 };
61 run_with_function(call, input, compute_variance(sample))
62 }
63
64 fn run_const(
65 &self,
66 working_set: &StateWorkingSet,
67 call: &Call,
68 input: PipelineData,
69 ) -> Result<PipelineData, ShellError> {
70 let sample = call.has_flag_const(working_set, "sample")?;
71 let name = call.head;
72 let span = input.span().unwrap_or(name);
73 let input: PipelineData = match input.try_expand_range() {
74 Err(_) => {
75 return Err(ShellError::IncorrectValue {
76 msg: "Range must be bounded".to_string(),
77 val_span: span,
78 call_span: name,
79 });
80 }
81 Ok(val) => val,
82 };
83 run_with_function(call, input, compute_variance(sample))
84 }
85
86 fn examples(&self) -> Vec<Example> {
87 vec![
88 Example {
89 description: "Get the variance of a list of numbers",
90 example: "[1 2 3 4 5] | math variance",
91 result: Some(Value::test_float(2.0)),
92 },
93 Example {
94 description: "Get the sample variance of a list of numbers",
95 example: "[1 2 3 4 5] | math variance --sample",
96 result: Some(Value::test_float(2.5)),
97 },
98 Example {
99 description: "Compute the variance of each column in a table",
100 example: "[[a b]; [1 2] [3 4]] | math variance",
101 result: Some(Value::test_record(record! {
102 "a" => Value::test_int(1),
103 "b" => Value::test_int(1),
104 })),
105 },
106 ]
107 }
108}
109
110fn sum_of_squares(values: &[Value], span: Span) -> Result<Value, ShellError> {
111 let n = Value::int(values.len() as i64, span);
112 let mut sum_x = Value::int(0, span);
113 let mut sum_x2 = Value::int(0, span);
114 for value in values {
115 let v = match &value {
116 Value::Int { .. } | Value::Float { .. } => Ok(value),
117 Value::Error { error, .. } => Err(*error.clone()),
118 other => Err(ShellError::UnsupportedInput {
119 msg: format!(
120 "Attempted to compute the sum of squares of a non-int, non-float value '{}' with a type of `{}`.",
121 other.coerce_string()?,
122 other.get_type()
123 ),
124 input: "value originates from here".into(),
125 msg_span: span,
126 input_span: value.span(),
127 }),
128 }?;
129 let v_squared = &v.mul(span, v, span)?;
130 sum_x2 = sum_x2.add(span, v_squared, span)?;
131 sum_x = sum_x.add(span, v, span)?;
132 }
133
134 let sum_x_squared = sum_x.mul(span, &sum_x, span)?;
135 let sum_x_squared_div_n = sum_x_squared.div(span, &n, span)?;
136
137 let ss = sum_x2.sub(span, &sum_x_squared_div_n, span)?;
138
139 Ok(ss)
140}
141
142pub fn compute_variance(
143 sample: bool,
144) -> impl Fn(&[Value], Span, Span) -> Result<Value, ShellError> {
145 move |values: &[Value], span: Span, head: Span| {
146 let n = if sample {
147 values.len() - 1
148 } else {
149 values.len()
150 };
151 let ss = sum_of_squares(values, span)?;
153 let n = Value::int(n as i64, head);
154 ss.div(head, &n, head)
155 }
156}
157
158#[cfg(test)]
159mod test {
160 use super::*;
161
162 #[test]
163 fn test_examples() {
164 use crate::test_examples;
165
166 test_examples(MathVariance {})
167 }
168}