datafusion_functions/math/
cot.rs1use std::sync::Arc;
19
20use arrow::array::AsArray;
21use arrow::datatypes::DataType::{Float32, Float64};
22use arrow::datatypes::{DataType, Float32Type, Float64Type};
23
24use datafusion_common::utils::take_function_args;
25use datafusion_common::{Result, ScalarValue, internal_err};
26use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
27use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
28use datafusion_macros::user_doc;
29
30#[user_doc(
31 doc_section(label = "Math Functions"),
32 description = "Returns the cotangent of a number.",
33 syntax_example = r#"cot(numeric_expression)"#,
34 sql_example = r#"```sql
35> SELECT cot(1);
36+---------+
37| cot(1) |
38+---------+
39| 0.64209 |
40+---------+
41```"#,
42 standard_argument(name = "numeric_expression", prefix = "Numeric")
43)]
44#[derive(Debug, PartialEq, Eq, Hash)]
45pub struct CotFunc {
46 signature: Signature,
47}
48
49impl Default for CotFunc {
50 fn default() -> Self {
51 CotFunc::new()
52 }
53}
54
55impl CotFunc {
56 pub fn new() -> Self {
57 use DataType::*;
58 Self {
59 signature: Signature::uniform(
65 1,
66 vec![Float64, Float32],
67 Volatility::Immutable,
68 ),
69 }
70 }
71}
72
73impl ScalarUDFImpl for CotFunc {
74 fn name(&self) -> &str {
75 "cot"
76 }
77
78 fn signature(&self) -> &Signature {
79 &self.signature
80 }
81
82 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
83 match arg_types[0] {
84 Float32 => Ok(Float32),
85 _ => Ok(Float64),
86 }
87 }
88
89 fn documentation(&self) -> Option<&Documentation> {
90 self.doc()
91 }
92
93 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
94 let return_field = args.return_field;
95 let [arg] = take_function_args(self.name(), args.args)?;
96
97 match arg {
98 ColumnarValue::Scalar(scalar) => {
99 if scalar.is_null() {
100 return ColumnarValue::Scalar(ScalarValue::Null)
101 .cast_to(return_field.data_type(), None);
102 }
103
104 match scalar {
105 ScalarValue::Float64(Some(v)) => Ok(ColumnarValue::Scalar(
106 ScalarValue::Float64(Some(compute_cot64(v))),
107 )),
108 ScalarValue::Float32(Some(v)) => Ok(ColumnarValue::Scalar(
109 ScalarValue::Float32(Some(compute_cot32(v))),
110 )),
111 _ => {
112 internal_err!(
113 "Unexpected scalar type for cot: {:?}",
114 scalar.data_type()
115 )
116 }
117 }
118 }
119 ColumnarValue::Array(array) => match array.data_type() {
120 Float64 => Ok(ColumnarValue::Array(Arc::new(
121 array
122 .as_primitive::<Float64Type>()
123 .unary::<_, Float64Type>(compute_cot64),
124 ))),
125 Float32 => Ok(ColumnarValue::Array(Arc::new(
126 array
127 .as_primitive::<Float32Type>()
128 .unary::<_, Float32Type>(compute_cot32),
129 ))),
130 other => {
131 internal_err!("Unexpected data type {other:?} for function cot")
132 }
133 },
134 }
135 }
136}
137
/// Single-precision cotangent, computed as the reciprocal of `tan(x)`.
///
/// Where `tan(x)` is a signed zero (e.g. `x == 0.0`), the result is the
/// infinity with that sign. NaN input yields NaN.
fn compute_cot32(x: f32) -> f32 {
    x.tan().recip()
}
142
/// Double-precision cotangent, computed as the reciprocal of `tan(x)`.
///
/// Where `tan(x)` is a signed zero (e.g. `x == 0.0`), the result is the
/// infinity with that sign. NaN input yields NaN.
fn compute_cot64(x: f64) -> f64 {
    x.tan().recip()
}
147
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::math::cot::CotFunc;

    /// Runs `CotFunc` on one argument of type `data_type` and returns its output.
    ///
    /// `arg_nullable` and `ret_nullable` set the nullability of the argument
    /// field ("a") and the return field ("f"). Panics with `failure_msg` if
    /// the invocation returns an error.
    fn invoke_cot(
        input: ColumnarValue,
        data_type: DataType,
        arg_nullable: bool,
        ret_nullable: bool,
        number_rows: usize,
        failure_msg: &str,
    ) -> ColumnarValue {
        let args = ScalarFunctionArgs {
            args: vec![input],
            arg_fields: vec![Field::new("a", data_type.clone(), arg_nullable).into()],
            number_rows,
            return_field: Field::new("f", data_type, ret_nullable).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };
        CotFunc::new().invoke_with_args(args).expect(failure_msg)
    }

    #[test]
    fn test_cot_f32() {
        let values = Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]);
        let rows = values.len();
        let input: ArrayRef = Arc::new(values);
        let result = invoke_cot(
            ColumnarValue::Array(input),
            DataType::Float32,
            false,
            true,
            rows,
            "failed to initialize function cot",
        );

        let ColumnarValue::Array(arr) = result else {
            panic!("Expected an array value")
        };
        let floats =
            as_float32_array(&arr).expect("failed to convert result to a Float32Array");

        let expected: [f32; 4] = [-1.986_460_4, -0.156_119_96, -0.501_202_8, 0.156_119_96];
        let eps = 1e-6_f32;
        assert_eq!(floats.len(), 4);
        for (actual, want) in floats.values().iter().zip(expected) {
            assert!((actual - want).abs() < eps);
        }
    }

    #[test]
    fn test_cot_f64() {
        let values = Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]);
        let rows = values.len();
        let input: ArrayRef = Arc::new(values);
        let result = invoke_cot(
            ColumnarValue::Array(input),
            DataType::Float64,
            false,
            true,
            rows,
            "failed to initialize function cot",
        );

        let ColumnarValue::Array(arr) = result else {
            panic!("Expected an array value")
        };
        let floats =
            as_float64_array(&arr).expect("failed to convert result to a Float64Array");

        let expected: [f64; 4] = [
            -1.986_458_685_881_4,
            -0.156_119_952_161_6,
            -0.501_202_783_380_1,
            0.156_119_952_161_6,
        ];
        let eps = 1e-12_f64;
        assert_eq!(floats.len(), 4);
        for (actual, want) in floats.values().iter().zip(expected) {
            assert!((actual - want).abs() < eps);
        }
    }

    #[test]
    fn test_cot_scalar_f64() {
        let result = invoke_cot(
            ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))),
            DataType::Float64,
            false,
            false,
            1,
            "cot scalar should succeed",
        );

        let ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) = result else {
            panic!("Expected Float64 scalar")
        };
        let expected = 1.0_f64 / 1.0_f64.tan();
        assert!((v - expected).abs() < 1e-12);
    }

    #[test]
    fn test_cot_scalar_f32() {
        let result = invoke_cot(
            ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))),
            DataType::Float32,
            false,
            false,
            1,
            "cot scalar should succeed",
        );

        let ColumnarValue::Scalar(ScalarValue::Float32(Some(v))) = result else {
            panic!("Expected Float32 scalar")
        };
        let expected = 1.0_f32 / 1.0_f32.tan();
        assert!((v - expected).abs() < 1e-6);
    }

    #[test]
    fn test_cot_scalar_null() {
        let result = invoke_cot(
            ColumnarValue::Scalar(ScalarValue::Float64(None)),
            DataType::Float64,
            true,
            true,
            1,
            "cot null should succeed",
        );

        let ColumnarValue::Scalar(scalar) = result else {
            panic!("Expected scalar result")
        };
        assert!(scalar.is_null());
    }

    #[test]
    fn test_cot_scalar_zero() {
        // tan(0) == 0, so the reciprocal must be infinite rather than an error.
        let result = invoke_cot(
            ColumnarValue::Scalar(ScalarValue::Float64(Some(0.0))),
            DataType::Float64,
            false,
            false,
            1,
            "cot zero should succeed",
        );

        let ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) = result else {
            panic!("Expected Float64 scalar")
        };
        assert!(v.is_infinite());
    }

    #[test]
    fn test_cot_scalar_pi() {
        let result = invoke_cot(
            ColumnarValue::Scalar(ScalarValue::Float64(Some(std::f64::consts::PI))),
            DataType::Float64,
            false,
            false,
            1,
            "cot pi should succeed",
        );

        let ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) = result else {
            panic!("Expected Float64 scalar")
        };
        let expected = 1.0_f64 / std::f64::consts::PI.tan();
        assert!((v - expected).abs() < 1e-6);
    }
}