datafusion_python/common/function.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::collections::HashMap;
19
20use datafusion::arrow::datatypes::DataType;
21use pyo3::prelude::*;
22
23use super::data_type::PyDataType;
24
25#[pyclass(name = "SqlFunction", module = "datafusion.common", subclass)]
26#[derive(Debug, Clone)]
27pub struct SqlFunction {
28 pub name: String,
29 pub return_types: HashMap<Vec<DataType>, DataType>,
30 pub aggregation: bool,
31}
32
33impl SqlFunction {
34 pub fn new(
35 function_name: String,
36 input_types: Vec<PyDataType>,
37 return_type: PyDataType,
38 aggregation_bool: bool,
39 ) -> Self {
40 let mut func = Self {
41 name: function_name,
42 return_types: HashMap::new(),
43 aggregation: aggregation_bool,
44 };
45 func.add_type_mapping(input_types, return_type);
46 func
47 }
48
49 pub fn add_type_mapping(&mut self, input_types: Vec<PyDataType>, return_type: PyDataType) {
50 self.return_types.insert(
51 input_types.iter().map(|t| t.clone().into()).collect(),
52 return_type.into(),
53 );
54 }
55}