reifydb_function/math/scalar/
sign.rs1use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::r#type::Type;
7
8use crate::{
9 ScalarFunction, ScalarFunctionContext,
10 error::{ScalarFunctionError, ScalarFunctionResult},
11 propagate_options,
12};
13
/// Scalar function object for the numeric `sign` operation.
///
/// The type carries no state; all behaviour lives in its `ScalarFunction`
/// implementation.
pub struct Sign;

impl Sign {
    /// Creates a new `Sign` function instance.
    pub fn new() -> Self {
        Self
    }
}

impl Default for Sign {
    /// Equivalent to [`Sign::new`].
    fn default() -> Self {
        Self::new()
    }
}
21
22fn numeric_to_f64(data: &ColumnData, i: usize) -> Option<f64> {
23 match data {
24 ColumnData::Int1(c) => c.get(i).map(|&v| v as f64),
25 ColumnData::Int2(c) => c.get(i).map(|&v| v as f64),
26 ColumnData::Int4(c) => c.get(i).map(|&v| v as f64),
27 ColumnData::Int8(c) => c.get(i).map(|&v| v as f64),
28 ColumnData::Int16(c) => c.get(i).map(|&v| v as f64),
29 ColumnData::Uint1(c) => c.get(i).map(|&v| v as f64),
30 ColumnData::Uint2(c) => c.get(i).map(|&v| v as f64),
31 ColumnData::Uint4(c) => c.get(i).map(|&v| v as f64),
32 ColumnData::Uint8(c) => c.get(i).map(|&v| v as f64),
33 ColumnData::Uint16(c) => c.get(i).map(|&v| v as f64),
34 ColumnData::Float4(c) => c.get(i).map(|&v| v as f64),
35 ColumnData::Float8(c) => c.get(i).copied(),
36 ColumnData::Int {
37 container,
38 ..
39 } => container.get(i).map(|v| v.0.to_f64().unwrap_or(0.0)),
40 ColumnData::Uint {
41 container,
42 ..
43 } => container.get(i).map(|v| v.0.to_f64().unwrap_or(0.0)),
44 ColumnData::Decimal {
45 container,
46 ..
47 } => container.get(i).map(|v| v.0.to_f64().unwrap_or(0.0)),
48 _ => None,
49 }
50}
51
52impl ScalarFunction for Sign {
53 fn scalar(&self, ctx: ScalarFunctionContext) -> ScalarFunctionResult<ColumnData> {
54 if let Some(result) = propagate_options(self, &ctx) {
55 return result;
56 }
57 let columns = ctx.columns;
58 let row_count = ctx.row_count;
59
60 if columns.len() != 1 {
61 return Err(ScalarFunctionError::ArityMismatch {
62 function: ctx.fragment.clone(),
63 expected: 1,
64 actual: columns.len(),
65 });
66 }
67
68 let column = columns.get(0).unwrap();
69
70 if !column.data().get_type().is_number() {
71 return Err(ScalarFunctionError::InvalidArgumentType {
72 function: ctx.fragment.clone(),
73 argument_index: 0,
74 expected: vec![
75 Type::Int1,
76 Type::Int2,
77 Type::Int4,
78 Type::Int8,
79 Type::Int16,
80 Type::Uint1,
81 Type::Uint2,
82 Type::Uint4,
83 Type::Uint8,
84 Type::Uint16,
85 Type::Float4,
86 Type::Float8,
87 Type::Int,
88 Type::Uint,
89 Type::Decimal,
90 ],
91 actual: column.data().get_type(),
92 });
93 }
94
95 let mut result = Vec::with_capacity(row_count);
96 let mut bitvec = Vec::with_capacity(row_count);
97
98 for i in 0..row_count {
99 match numeric_to_f64(column.data(), i) {
100 Some(v) => {
101 let sign = if v > 0.0 {
102 1i32
103 } else if v < 0.0 {
104 -1i32
105 } else {
106 0i32
107 };
108 result.push(sign);
109 bitvec.push(true);
110 }
111 None => {
112 result.push(0);
113 bitvec.push(false);
114 }
115 }
116 }
117
118 Ok(ColumnData::int4_with_bitvec(result, bitvec))
119 }
120
121 fn return_type(&self, _input_types: &[Type]) -> Type {
122 Type::Float8
123 }
124}