// datafusion_functions/math/cot.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::AsArray;
21use arrow::datatypes::DataType::{Float32, Float64};
22use arrow::datatypes::{DataType, Float32Type, Float64Type};
23
24use datafusion_common::utils::take_function_args;
25use datafusion_common::{Result, ScalarValue, internal_err};
26use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
27use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
28use datafusion_macros::user_doc;
29
// Scalar UDF implementing SQL `cot(x)`, the cotangent of `x`.
//
// The `user_doc` attribute below describes the function for the SQL reference
// docs; `ScalarUDFImpl::documentation` returns it through `self.doc()`
// (presumably generated by the macro — confirm in `datafusion_macros`).
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Returns the cotangent of a number.",
    syntax_example = r#"cot(numeric_expression)"#,
    sql_example = r#"```sql
> SELECT cot(1);
+---------+
| cot(1)  |
+---------+
| 0.64209 |
+---------+
```"#,
    standard_argument(name = "numeric_expression", prefix = "Numeric")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct CotFunc {
    // Accepted argument types (one Float64 or Float32 argument) and
    // volatility; built once in `CotFunc::new` and returned by `signature()`.
    signature: Signature,
}
48
49impl Default for CotFunc {
50    fn default() -> Self {
51        CotFunc::new()
52    }
53}
54
55impl CotFunc {
56    pub fn new() -> Self {
57        use DataType::*;
58        Self {
59            // math expressions expect 1 argument of type f64 or f32
60            // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
61            // return the best approximation for it (in f64).
62            // We accept f32 because in this case it is clear that the best approximation
63            // will be as good as the number of digits in the number
64            signature: Signature::uniform(
65                1,
66                vec![Float64, Float32],
67                Volatility::Immutable,
68            ),
69        }
70    }
71}
72
73impl ScalarUDFImpl for CotFunc {
74    fn name(&self) -> &str {
75        "cot"
76    }
77
78    fn signature(&self) -> &Signature {
79        &self.signature
80    }
81
82    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
83        match arg_types[0] {
84            Float32 => Ok(Float32),
85            _ => Ok(Float64),
86        }
87    }
88
89    fn documentation(&self) -> Option<&Documentation> {
90        self.doc()
91    }
92
93    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
94        let return_field = args.return_field;
95        let [arg] = take_function_args(self.name(), args.args)?;
96
97        match arg {
98            ColumnarValue::Scalar(scalar) => {
99                if scalar.is_null() {
100                    return ColumnarValue::Scalar(ScalarValue::Null)
101                        .cast_to(return_field.data_type(), None);
102                }
103
104                match scalar {
105                    ScalarValue::Float64(Some(v)) => Ok(ColumnarValue::Scalar(
106                        ScalarValue::Float64(Some(compute_cot64(v))),
107                    )),
108                    ScalarValue::Float32(Some(v)) => Ok(ColumnarValue::Scalar(
109                        ScalarValue::Float32(Some(compute_cot32(v))),
110                    )),
111                    _ => {
112                        internal_err!(
113                            "Unexpected scalar type for cot: {:?}",
114                            scalar.data_type()
115                        )
116                    }
117                }
118            }
119            ColumnarValue::Array(array) => match array.data_type() {
120                Float64 => Ok(ColumnarValue::Array(Arc::new(
121                    array
122                        .as_primitive::<Float64Type>()
123                        .unary::<_, Float64Type>(compute_cot64),
124                ))),
125                Float32 => Ok(ColumnarValue::Array(Arc::new(
126                    array
127                        .as_primitive::<Float32Type>()
128                        .unary::<_, Float32Type>(compute_cot32),
129                ))),
130                other => {
131                    internal_err!("Unexpected data type {other:?} for function cot")
132                }
133            },
134        }
135    }
136}
137
/// Single-precision cotangent, evaluated as the reciprocal of `tan(x)`.
///
/// When `tan(x)` is a signed zero (e.g. `x == 0.0`) the result is the
/// correspondingly signed infinity; NaN and infinite inputs yield NaN.
fn compute_cot32(x: f32) -> f32 {
    x.tan().recip()
}
142
/// Double-precision cotangent, evaluated as the reciprocal of `tan(x)`.
///
/// When `tan(x)` is a signed zero (e.g. `x == 0.0`) the result is the
/// correspondingly signed infinity; NaN and infinite inputs yield NaN.
fn compute_cot64(x: f64) -> f64 {
    x.tan().recip()
}
147
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::math::cot::CotFunc;

    /// Invokes `cot` on one argument of type `data_type`.
    ///
    /// Holds the `ScalarFunctionArgs` boilerplate shared by every test. The
    /// return field reuses `data_type`, which is what `CotFunc::return_type`
    /// yields for the `Float32`/`Float64` inputs used in this module.
    fn invoke_cot(
        arg: ColumnarValue,
        data_type: DataType,
        nullable: bool,
        number_rows: usize,
    ) -> ColumnarValue {
        let args = ScalarFunctionArgs {
            args: vec![arg],
            arg_fields: vec![Field::new("a", data_type.clone(), nullable).into()],
            number_rows,
            return_field: Field::new("f", data_type, nullable).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };
        CotFunc::new()
            .invoke_with_args(args)
            .expect("cot invocation should succeed")
    }

    /// Invokes `cot` on a single scalar and unwraps the scalar result.
    fn invoke_cot_scalar(scalar: ScalarValue) -> ScalarValue {
        let data_type = scalar.data_type();
        let nullable = scalar.is_null();
        match invoke_cot(ColumnarValue::Scalar(scalar), data_type, nullable, 1) {
            ColumnarValue::Scalar(result) => result,
            ColumnarValue::Array(_) => panic!("Expected scalar result"),
        }
    }

    /// Invokes `cot` on an array and unwraps the array result.
    fn invoke_cot_array(array: ArrayRef, nullable: bool) -> ArrayRef {
        let data_type = array.data_type().clone();
        let number_rows = array.len();
        match invoke_cot(ColumnarValue::Array(array), data_type, nullable, number_rows)
        {
            ColumnarValue::Array(result) => result,
            ColumnarValue::Scalar(_) => panic!("Expected an array value"),
        }
    }

    #[test]
    fn test_cot_f32() {
        let array: ArrayRef =
            Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]));
        let result = invoke_cot_array(array, false);
        let floats = as_float32_array(&result)
            .expect("failed to convert result to a Float32Array");

        let expected = [-1.986_460_4_f32, -0.156_119_96, -0.501_202_8, 0.156_119_96];
        let eps = 1e-6_f32;
        assert_eq!(floats.len(), expected.len());
        for (i, &want) in expected.iter().enumerate() {
            let got = floats.value(i);
            assert!((got - want).abs() < eps, "index {i}: got {got}, want {want}");
        }
    }

    #[test]
    fn test_cot_f64() {
        let array: ArrayRef =
            Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]));
        let result = invoke_cot_array(array, false);
        let floats = as_float64_array(&result)
            .expect("failed to convert result to a Float64Array");

        let expected = [
            -1.986_458_685_881_4_f64,
            -0.156_119_952_161_6,
            -0.501_202_783_380_1,
            0.156_119_952_161_6,
        ];
        let eps = 1e-12_f64;
        assert_eq!(floats.len(), expected.len());
        for (i, &want) in expected.iter().enumerate() {
            let got = floats.value(i);
            assert!((got - want).abs() < eps, "index {i}: got {got}, want {want}");
        }
    }

    #[test]
    fn test_cot_array_with_nulls() {
        let array: ArrayRef =
            Arc::new(Float64Array::from(vec![Some(1.0), None, Some(-1.0)]));
        let result = invoke_cot_array(array, true);
        let floats = as_float64_array(&result)
            .expect("failed to convert result to a Float64Array");

        // cot is odd, so cot(-1) = -cot(1); the null slot must stay null.
        let cot_one = 1.0_f64 / 1.0_f64.tan();
        assert_eq!(floats.len(), 3);
        assert!(floats.is_valid(0));
        assert!(floats.is_null(1));
        assert!(floats.is_valid(2));
        assert!((floats.value(0) - cot_one).abs() < 1e-12);
        assert!((floats.value(2) + cot_one).abs() < 1e-12);
    }

    #[test]
    fn test_cot_scalar_f64() {
        // cot(1.0) = 1/tan(1.0) ≈ 0.6420926159343306
        let expected = 1.0_f64 / 1.0_f64.tan();
        match invoke_cot_scalar(ScalarValue::Float64(Some(1.0))) {
            ScalarValue::Float64(Some(v)) => assert!((v - expected).abs() < 1e-12),
            other => panic!("Expected Float64 scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_cot_scalar_f32() {
        let expected = 1.0_f32 / 1.0_f32.tan();
        match invoke_cot_scalar(ScalarValue::Float32(Some(1.0))) {
            ScalarValue::Float32(Some(v)) => assert!((v - expected).abs() < 1e-6),
            other => panic!("Expected Float32 scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_cot_scalar_null() {
        for null in [ScalarValue::Float64(None), ScalarValue::Float32(None)] {
            let expected_type = null.data_type();
            let result = invoke_cot_scalar(null);
            assert!(result.is_null());
            // The null must be typed like the return type, not `ScalarValue::Null`.
            assert_eq!(result.data_type(), expected_type);
        }
    }

    #[test]
    fn test_cot_scalar_zero() {
        // cot(0) = 1/tan(0) = infinity
        match invoke_cot_scalar(ScalarValue::Float64(Some(0.0))) {
            ScalarValue::Float64(Some(v)) => assert!(v.is_infinite()),
            other => panic!("Expected Float64 scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_cot_scalar_pi() {
        // tan(PI) is a tiny non-zero value in floating point, so cot(PI) is a
        // very large negative number rather than an infinity.
        let expected = 1.0_f64 / std::f64::consts::PI.tan();
        match invoke_cot_scalar(ScalarValue::Float64(Some(std::f64::consts::PI))) {
            ScalarValue::Float64(Some(v)) => assert!((v - expected).abs() < 1e-6),
            other => panic!("Expected Float64 scalar, got {other:?}"),
        }
    }
}