datafusion_functions/math/
cot.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::AsArray;
22use arrow::datatypes::DataType::{Float32, Float64};
23use arrow::datatypes::{DataType, Float32Type, Float64Type};
24
25use datafusion_common::utils::take_function_args;
26use datafusion_common::{Result, ScalarValue, internal_err};
27use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
28use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
29use datafusion_macros::user_doc;
30
31#[user_doc(
32 doc_section(label = "Math Functions"),
33 description = "Returns the cotangent of a number.",
34 syntax_example = r#"cot(numeric_expression)"#,
35 sql_example = r#"```sql
36> SELECT cot(1);
37+---------+
38| cot(1) |
39+---------+
40| 0.64209 |
41+---------+
42```"#,
43 standard_argument(name = "numeric_expression", prefix = "Numeric")
44)]
45#[derive(Debug, PartialEq, Eq, Hash)]
46pub struct CotFunc {
47 signature: Signature,
48}
49
50impl Default for CotFunc {
51 fn default() -> Self {
52 CotFunc::new()
53 }
54}
55
56impl CotFunc {
57 pub fn new() -> Self {
58 use DataType::*;
59 Self {
60 signature: Signature::uniform(
66 1,
67 vec![Float64, Float32],
68 Volatility::Immutable,
69 ),
70 }
71 }
72}
73
74impl ScalarUDFImpl for CotFunc {
75 fn as_any(&self) -> &dyn Any {
76 self
77 }
78
79 fn name(&self) -> &str {
80 "cot"
81 }
82
83 fn signature(&self) -> &Signature {
84 &self.signature
85 }
86
87 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
88 match arg_types[0] {
89 Float32 => Ok(Float32),
90 _ => Ok(Float64),
91 }
92 }
93
94 fn documentation(&self) -> Option<&Documentation> {
95 self.doc()
96 }
97
98 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
99 let return_field = args.return_field;
100 let [arg] = take_function_args(self.name(), args.args)?;
101
102 match arg {
103 ColumnarValue::Scalar(scalar) => {
104 if scalar.is_null() {
105 return ColumnarValue::Scalar(ScalarValue::Null)
106 .cast_to(return_field.data_type(), None);
107 }
108
109 match scalar {
110 ScalarValue::Float64(Some(v)) => Ok(ColumnarValue::Scalar(
111 ScalarValue::Float64(Some(compute_cot64(v))),
112 )),
113 ScalarValue::Float32(Some(v)) => Ok(ColumnarValue::Scalar(
114 ScalarValue::Float32(Some(compute_cot32(v))),
115 )),
116 _ => {
117 internal_err!(
118 "Unexpected scalar type for cot: {:?}",
119 scalar.data_type()
120 )
121 }
122 }
123 }
124 ColumnarValue::Array(array) => match array.data_type() {
125 Float64 => Ok(ColumnarValue::Array(Arc::new(
126 array
127 .as_primitive::<Float64Type>()
128 .unary::<_, Float64Type>(compute_cot64),
129 ))),
130 Float32 => Ok(ColumnarValue::Array(Arc::new(
131 array
132 .as_primitive::<Float32Type>()
133 .unary::<_, Float32Type>(compute_cot32),
134 ))),
135 other => {
136 internal_err!("Unexpected data type {other:?} for function cot")
137 }
138 },
139 }
140 }
141}
142
/// Cotangent of an `f32`, i.e. the reciprocal of its tangent.
///
/// IEEE-754 division semantics apply: a zero tangent yields an infinity
/// carrying the zero's sign, and NaN input stays NaN.
fn compute_cot32(x: f32) -> f32 {
    x.tan().recip()
}
147
/// Cotangent of an `f64`, i.e. the reciprocal of its tangent.
///
/// IEEE-754 division semantics apply: a zero tangent yields an infinity
/// carrying the zero's sign, and NaN input stays NaN.
fn compute_cot64(x: f64) -> f64 {
    x.tan().recip()
}
152
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{Array, ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    use crate::math::cot::CotFunc;

    /// Invokes `cot` with one argument. Both the argument field and the return
    /// field are declared as nullable `data_type`.
    fn invoke_cot(
        arg: ColumnarValue,
        data_type: DataType,
        number_rows: usize,
    ) -> ColumnarValue {
        let args = ScalarFunctionArgs {
            args: vec![arg],
            arg_fields: vec![Field::new("a", data_type.clone(), true).into()],
            number_rows,
            return_field: Field::new("f", data_type, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        };
        CotFunc::new()
            .invoke_with_args(args)
            .expect("failed to invoke function cot")
    }

    /// Unwraps an array result, panicking if a scalar came back.
    fn expect_array(result: ColumnarValue) -> ArrayRef {
        match result {
            ColumnarValue::Array(arr) => arr,
            ColumnarValue::Scalar(_) => panic!("Expected an array value"),
        }
    }

    /// Invokes `cot` on a non-null `Float64` scalar and returns the value.
    fn cot_scalar_f64(x: f64) -> f64 {
        match invoke_cot(
            ColumnarValue::Scalar(ScalarValue::Float64(Some(x))),
            DataType::Float64,
            1,
        ) {
            ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => v,
            other => panic!("Expected Float64 scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_cot_f32() {
        let array: ArrayRef =
            Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]));
        let arr = expect_array(invoke_cot(
            ColumnarValue::Array(array),
            DataType::Float32,
            4,
        ));
        let floats =
            as_float32_array(&arr).expect("failed to convert result to a Float32Array");

        let expected: [f32; 4] = [-1.986_460_4, -0.156_119_96, -0.501_202_8, 0.156_119_96];
        assert_eq!(floats.len(), expected.len());
        for (i, want) in expected.into_iter().enumerate() {
            let got = floats.value(i);
            assert!((got - want).abs() < 1e-6, "row {i}: got {got}, want {want}");
        }
    }

    #[test]
    fn test_cot_f64() {
        let array: ArrayRef =
            Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]));
        let arr = expect_array(invoke_cot(
            ColumnarValue::Array(array),
            DataType::Float64,
            4,
        ));
        let floats =
            as_float64_array(&arr).expect("failed to convert result to a Float64Array");

        let expected: [f64; 4] = [
            -1.986_458_685_881_4,
            -0.156_119_952_161_6,
            -0.501_202_783_380_1,
            0.156_119_952_161_6,
        ];
        assert_eq!(floats.len(), expected.len());
        for (i, want) in expected.into_iter().enumerate() {
            let got = floats.value(i);
            assert!((got - want).abs() < 1e-12, "row {i}: got {got}, want {want}");
        }
    }

    #[test]
    fn test_cot_array_preserves_nulls() {
        let array: ArrayRef = Arc::new(Float64Array::from(vec![Some(1.0), None]));
        let arr = expect_array(invoke_cot(
            ColumnarValue::Array(array),
            DataType::Float64,
            2,
        ));
        let floats =
            as_float64_array(&arr).expect("failed to convert result to a Float64Array");

        assert_eq!(floats.len(), 2);
        assert!(floats.is_valid(0));
        assert!(floats.is_null(1));
        assert!((floats.value(0) - 1.0_f64 / 1.0_f64.tan()).abs() < 1e-12);
    }

    #[test]
    fn test_cot_scalar_f64() {
        let v = cot_scalar_f64(1.0);
        assert!((v - 1.0_f64 / 1.0_f64.tan()).abs() < 1e-12);
    }

    #[test]
    fn test_cot_scalar_f32() {
        let result = invoke_cot(
            ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))),
            DataType::Float32,
            1,
        );
        match result {
            ColumnarValue::Scalar(ScalarValue::Float32(Some(v))) => {
                assert!((v - 1.0_f32 / 1.0_f32.tan()).abs() < 1e-6);
            }
            other => panic!("Expected Float32 scalar, got {other:?}"),
        }
    }

    #[test]
    fn test_cot_scalar_null() {
        let result = invoke_cot(
            ColumnarValue::Scalar(ScalarValue::Float64(None)),
            DataType::Float64,
            1,
        );
        match result {
            // The null must be typed with the return type, not left untyped.
            ColumnarValue::Scalar(scalar) => {
                assert_eq!(scalar, ScalarValue::Float64(None));
            }
            other => panic!("Expected scalar result, got {other:?}"),
        }
    }

    #[test]
    fn test_cot_scalar_null_f32() {
        let result = invoke_cot(
            ColumnarValue::Scalar(ScalarValue::Float32(None)),
            DataType::Float32,
            1,
        );
        match result {
            ColumnarValue::Scalar(scalar) => {
                assert_eq!(scalar, ScalarValue::Float32(None));
            }
            other => panic!("Expected scalar result, got {other:?}"),
        }
    }

    #[test]
    fn test_cot_scalar_zero() {
        // tan(+0.0) == +0.0, so the reciprocal is positive infinity.
        assert_eq!(cot_scalar_f64(0.0), f64::INFINITY);
    }

    #[test]
    fn test_cot_scalar_frac_pi_4() {
        let v = cot_scalar_f64(std::f64::consts::FRAC_PI_4);
        assert!((v - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_cot_scalar_pi() {
        let v = cot_scalar_f64(std::f64::consts::PI);
        let expected = 1.0_f64 / std::f64::consts::PI.tan();
        assert!((v - expected).abs() < 1e-6);
        // The f64 `PI` sits just below the true pi, so tan(PI) is a tiny
        // negative number. Its reciprocal is therefore huge and negative,
        // but finite.
        assert!(v.is_finite());
        assert!(v < 0.0);
    }
}