datafusion_python/common/
function.rs1use std::collections::HashMap;
19
20use datafusion::arrow::datatypes::DataType;
21use pyo3::prelude::*;
22
23use super::data_type::PyDataType;
24
25#[pyclass(
26 from_py_object,
27 frozen,
28 name = "SqlFunction",
29 module = "datafusion.common",
30 subclass
31)]
32#[derive(Debug, Clone)]
33pub struct SqlFunction {
34 pub name: String,
35 pub return_types: HashMap<Vec<DataType>, DataType>,
36 pub aggregation: bool,
37}
38
39impl SqlFunction {
40 pub fn new(
41 function_name: String,
42 input_types: Vec<PyDataType>,
43 return_type: PyDataType,
44 aggregation_bool: bool,
45 ) -> Self {
46 let mut func = Self {
47 name: function_name,
48 return_types: HashMap::new(),
49 aggregation: aggregation_bool,
50 };
51 func.add_type_mapping(input_types, return_type);
52 func
53 }
54
55 pub fn add_type_mapping(&mut self, input_types: Vec<PyDataType>, return_type: PyDataType) {
56 self.return_types.insert(
57 input_types.iter().map(|t| t.clone().into()).collect(),
58 return_type.into(),
59 );
60 }
61}