reifydb_function/math/scalar/
sqrt.rs1use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::r#type::Type;
7
8use crate::{
9 ScalarFunction, ScalarFunctionContext,
10 error::{ScalarFunctionError, ScalarFunctionResult},
11 propagate_options,
12};
13
/// Scalar function computing the square root of a single numeric argument.
///
/// This is a unit struct with no fields; all inputs are supplied per call
/// through the `ScalarFunction::scalar` context.
#[derive(Debug, Default, Clone, Copy)]
pub struct Sqrt;

impl Sqrt {
	/// Creates a new `Sqrt` function instance (equivalent to `Sqrt::default()`).
	pub fn new() -> Self {
		Self
	}
}
21
22fn numeric_to_f64(data: &ColumnData, i: usize) -> Option<f64> {
23 match data {
24 ColumnData::Int1(c) => c.get(i).map(|&v| v as f64),
25 ColumnData::Int2(c) => c.get(i).map(|&v| v as f64),
26 ColumnData::Int4(c) => c.get(i).map(|&v| v as f64),
27 ColumnData::Int8(c) => c.get(i).map(|&v| v as f64),
28 ColumnData::Int16(c) => c.get(i).map(|&v| v as f64),
29 ColumnData::Uint1(c) => c.get(i).map(|&v| v as f64),
30 ColumnData::Uint2(c) => c.get(i).map(|&v| v as f64),
31 ColumnData::Uint4(c) => c.get(i).map(|&v| v as f64),
32 ColumnData::Uint8(c) => c.get(i).map(|&v| v as f64),
33 ColumnData::Uint16(c) => c.get(i).map(|&v| v as f64),
34 ColumnData::Float4(c) => c.get(i).map(|&v| v as f64),
35 ColumnData::Float8(c) => c.get(i).copied(),
36 ColumnData::Int {
37 container,
38 ..
39 } => container.get(i).map(|v| v.0.to_f64().unwrap_or(0.0)),
40 ColumnData::Uint {
41 container,
42 ..
43 } => container.get(i).map(|v| v.0.to_f64().unwrap_or(0.0)),
44 ColumnData::Decimal {
45 container,
46 ..
47 } => container.get(i).map(|v| v.0.to_f64().unwrap_or(0.0)),
48 _ => None,
49 }
50}
51
52impl ScalarFunction for Sqrt {
53 fn scalar(&self, ctx: ScalarFunctionContext) -> ScalarFunctionResult<ColumnData> {
54 if let Some(result) = propagate_options(self, &ctx) {
55 return result;
56 }
57 let columns = ctx.columns;
58 let row_count = ctx.row_count;
59
60 if columns.len() != 1 {
61 return Err(ScalarFunctionError::ArityMismatch {
62 function: ctx.fragment.clone(),
63 expected: 1,
64 actual: columns.len(),
65 });
66 }
67
68 let column = columns.get(0).unwrap();
69
70 if !column.data().get_type().is_number() {
71 return Err(ScalarFunctionError::InvalidArgumentType {
72 function: ctx.fragment.clone(),
73 argument_index: 0,
74 expected: vec![
75 Type::Int1,
76 Type::Int2,
77 Type::Int4,
78 Type::Int8,
79 Type::Int16,
80 Type::Uint1,
81 Type::Uint2,
82 Type::Uint4,
83 Type::Uint8,
84 Type::Uint16,
85 Type::Float4,
86 Type::Float8,
87 Type::Int,
88 Type::Uint,
89 Type::Decimal,
90 ],
91 actual: column.data().get_type(),
92 });
93 }
94
95 let mut result = Vec::with_capacity(row_count);
96 let mut bitvec = Vec::with_capacity(row_count);
97
98 for i in 0..row_count {
99 match numeric_to_f64(column.data(), i) {
100 Some(v) => {
101 result.push(v.sqrt());
102 bitvec.push(true);
103 }
104 None => {
105 result.push(0.0);
106 bitvec.push(false);
107 }
108 }
109 }
110
111 Ok(ColumnData::float8_with_bitvec(result, bitvec))
112 }
113
114 fn return_type(&self, _input_types: &[Type]) -> Type {
115 Type::Float8
116 }
117}