reifydb_function/math/scalar/
modulo.rs1use num_traits::ToPrimitive;
5use reifydb_core::value::column::data::ColumnData;
6use reifydb_type::value::r#type::Type;
7
8use crate::{
9 ScalarFunction, ScalarFunctionContext,
10 error::{ScalarFunctionError, ScalarFunctionResult},
11 propagate_options,
12};
13
/// Scalar function that computes the remainder of its first argument divided
/// by its second, producing a `Float8` column.
#[derive(Default)]
pub struct Modulo;

impl Modulo {
    /// Creates a new `Modulo` function instance.
    ///
    /// Equivalent to `Modulo::default()`; the type carries no state.
    pub fn new() -> Self {
        Self
    }
}
21
22fn numeric_to_f64(data: &ColumnData, i: usize) -> Option<f64> {
23 match data {
24 ColumnData::Int1(c) => c.get(i).map(|&v| v as f64),
25 ColumnData::Int2(c) => c.get(i).map(|&v| v as f64),
26 ColumnData::Int4(c) => c.get(i).map(|&v| v as f64),
27 ColumnData::Int8(c) => c.get(i).map(|&v| v as f64),
28 ColumnData::Int16(c) => c.get(i).map(|&v| v as f64),
29 ColumnData::Uint1(c) => c.get(i).map(|&v| v as f64),
30 ColumnData::Uint2(c) => c.get(i).map(|&v| v as f64),
31 ColumnData::Uint4(c) => c.get(i).map(|&v| v as f64),
32 ColumnData::Uint8(c) => c.get(i).map(|&v| v as f64),
33 ColumnData::Uint16(c) => c.get(i).map(|&v| v as f64),
34 ColumnData::Float4(c) => c.get(i).map(|&v| v as f64),
35 ColumnData::Float8(c) => c.get(i).copied(),
36 ColumnData::Int {
37 container,
38 ..
39 } => container.get(i).map(|v| v.0.to_f64().unwrap_or(0.0)),
40 ColumnData::Uint {
41 container,
42 ..
43 } => container.get(i).map(|v| v.0.to_f64().unwrap_or(0.0)),
44 ColumnData::Decimal {
45 container,
46 ..
47 } => container.get(i).map(|v| v.0.to_f64().unwrap_or(0.0)),
48 _ => None,
49 }
50}
51
52impl ScalarFunction for Modulo {
53 fn scalar(&self, ctx: ScalarFunctionContext) -> ScalarFunctionResult<ColumnData> {
54 if let Some(result) = propagate_options(self, &ctx) {
55 return result;
56 }
57 let columns = ctx.columns;
58 let row_count = ctx.row_count;
59
60 if columns.len() != 2 {
61 return Err(ScalarFunctionError::ArityMismatch {
62 function: ctx.fragment.clone(),
63 expected: 2,
64 actual: columns.len(),
65 });
66 }
67
68 let a_col = columns.get(0).unwrap();
69 let b_col = columns.get(1).unwrap();
70
71 if !a_col.data().get_type().is_number() {
72 return Err(ScalarFunctionError::InvalidArgumentType {
73 function: ctx.fragment.clone(),
74 argument_index: 0,
75 expected: vec![
76 Type::Int1,
77 Type::Int2,
78 Type::Int4,
79 Type::Int8,
80 Type::Int16,
81 Type::Uint1,
82 Type::Uint2,
83 Type::Uint4,
84 Type::Uint8,
85 Type::Uint16,
86 Type::Float4,
87 Type::Float8,
88 Type::Int,
89 Type::Uint,
90 Type::Decimal,
91 ],
92 actual: a_col.data().get_type(),
93 });
94 }
95
96 if !b_col.data().get_type().is_number() {
97 return Err(ScalarFunctionError::InvalidArgumentType {
98 function: ctx.fragment.clone(),
99 argument_index: 1,
100 expected: vec![
101 Type::Int1,
102 Type::Int2,
103 Type::Int4,
104 Type::Int8,
105 Type::Int16,
106 Type::Uint1,
107 Type::Uint2,
108 Type::Uint4,
109 Type::Uint8,
110 Type::Uint16,
111 Type::Float4,
112 Type::Float8,
113 Type::Int,
114 Type::Uint,
115 Type::Decimal,
116 ],
117 actual: b_col.data().get_type(),
118 });
119 }
120
121 let mut result = Vec::with_capacity(row_count);
122 let mut bitvec = Vec::with_capacity(row_count);
123
124 for i in 0..row_count {
125 match (numeric_to_f64(a_col.data(), i), numeric_to_f64(b_col.data(), i)) {
126 (Some(a), Some(b)) => {
127 if b == 0.0 {
128 result.push(f64::NAN);
129 } else {
130 result.push(a % b);
131 }
132 bitvec.push(true);
133 }
134 _ => {
135 result.push(0.0);
136 bitvec.push(false);
137 }
138 }
139 }
140
141 Ok(ColumnData::float8_with_bitvec(result, bitvec))
142 }
143
144 fn return_type(&self, _input_types: &[Type]) -> Type {
145 Type::Float8
146 }
147}