mf_expression/functions/
defs.rs

1//! 函数定义接口模块
2//!
3//! 定义了函数的基本接口和具体实现类型,包括静态函数和复合函数
4
5use crate::functions::arguments::Arguments;
6use crate::variable::VariableType;
7use crate::Variable;
8use std::any::Any;
9use std::collections::HashSet;
10use std::rc::Rc;
11
12/// 函数定义特征
13///
14/// 所有函数(内置、自定义、已废弃)都必须实现此特征
15pub trait FunctionDefinition: Any {
16    /// 返回必需参数的数量
17    fn required_parameters(&self) -> usize;
18    /// 返回可选参数的数量
19    fn optional_parameters(&self) -> usize;
20    /// 检查参数类型是否匹配
21    fn check_types(
22        &self,
23        args: &[Rc<VariableType>],
24    ) -> FunctionTypecheck;
25    /// 执行函数调用
26    fn call(
27        &self,
28        args: Arguments,
29    ) -> anyhow::Result<Variable>;
30    /// 获取指定位置参数的类型
31    fn param_type(
32        &self,
33        index: usize,
34    ) -> Option<VariableType>;
35    /// 获取指定位置参数类型的字符串表示
36    fn param_type_str(
37        &self,
38        index: usize,
39    ) -> String;
40    /// 获取函数返回值类型
41    fn return_type(&self) -> VariableType;
42    /// 获取函数返回值类型的字符串表示
43    fn return_type_str(&self) -> String;
44}
45
46/// 函数类型检查结果
47///
48/// 包含类型检查过程中发现的错误信息和推断的返回类型
49#[derive(Debug, Default)]
50pub struct FunctionTypecheck {
51    /// 通用错误信息(如参数数量不匹配)
52    pub general: Option<String>,
53    /// 参数错误列表:(参数索引, 错误信息)
54    pub arguments: Vec<(usize, String)>,
55    /// 推断的返回类型
56    pub return_type: VariableType,
57}
58
59/// 函数签名
60///
61/// 描述函数的参数类型和返回类型
62#[derive(Clone)]
63pub struct FunctionSignature {
64    /// 参数类型列表
65    pub parameters: Vec<VariableType>,
66    /// 返回类型
67    pub return_type: VariableType,
68}
69
70impl FunctionSignature {
71    /// 创建单参数函数签名
72    ///
73    /// # 参数
74    /// * `parameter` - 参数类型
75    /// * `return_type` - 返回类型
76    pub fn single(
77        parameter: VariableType,
78        return_type: VariableType,
79    ) -> Self {
80        Self { parameters: vec![parameter], return_type }
81    }
82}
83
84/// 静态函数
85///
86/// 具有固定签名的函数实现
87#[derive(Clone)]
88pub struct StaticFunction {
89    /// 函数签名
90    pub signature: FunctionSignature,
91    /// 函数实现
92    pub implementation: Rc<dyn Fn(Arguments) -> anyhow::Result<Variable>>,
93}
94
95impl FunctionDefinition for StaticFunction {
96    /// 返回必需参数数量(等于签名中的参数数量)
97    fn required_parameters(&self) -> usize {
98        self.signature.parameters.len()
99    }
100
101    /// 静态函数没有可选参数
102    fn optional_parameters(&self) -> usize {
103        0
104    }
105
106    /// 检查参数类型是否与签名匹配
107    fn check_types(
108        &self,
109        args: &[Rc<VariableType>],
110    ) -> FunctionTypecheck {
111        let mut typecheck = FunctionTypecheck::default();
112        typecheck.return_type = self.signature.return_type.clone();
113
114        // 检查参数数量
115        if args.len() != self.required_parameters() {
116            typecheck.general = Some(format!(
117                "期望 `{}` 参数, 实际 `{}` 参数.",
118                self.required_parameters(),
119                args.len()
120            ));
121        }
122
123        // 检查每个参数类型
124        for (i, (arg, expected_type)) in
125            args.iter().zip(self.signature.parameters.iter()).enumerate()
126        {
127            if !arg.satisfies(expected_type) {
128                typecheck.arguments.push((
129                    i,
130                    format!(
131                        "参数类型 `{arg}` 不能赋值给参数类型 `{expected_type}`.",
132                    ),
133                ));
134            }
135        }
136
137        typecheck
138    }
139
140    /// 执行函数调用
141    fn call(
142        &self,
143        args: Arguments,
144    ) -> anyhow::Result<Variable> {
145        (&self.implementation)(args)
146    }
147
148    /// 获取指定位置的参数类型
149    fn param_type(
150        &self,
151        index: usize,
152    ) -> Option<VariableType> {
153        self.signature.parameters.get(index).cloned()
154    }
155
156    /// 获取指定位置参数类型的字符串表示
157    fn param_type_str(
158        &self,
159        index: usize,
160    ) -> String {
161        self.signature
162            .parameters
163            .get(index)
164            .map(|x| x.to_string())
165            .unwrap_or_else(|| "never".to_string())
166    }
167
168    /// 获取返回类型
169    fn return_type(&self) -> VariableType {
170        self.signature.return_type.clone()
171    }
172
173    /// 获取返回类型的字符串表示
174    fn return_type_str(&self) -> String {
175        self.signature.return_type.to_string()
176    }
177}
178
179/// 复合函数
180///
181/// 支持多个函数重载的函数实现
182#[derive(Clone)]
183pub struct CompositeFunction {
184    /// 函数重载签名列表
185    pub signatures: Vec<FunctionSignature>,
186    /// 函数实现(需要根据参数类型选择合适的重载)
187    pub implementation: Rc<dyn Fn(Arguments) -> anyhow::Result<Variable>>,
188}
189
190impl FunctionDefinition for CompositeFunction {
191    /// 返回最少参数数量(所有重载中参数最少的)
192    fn required_parameters(&self) -> usize {
193        self.signatures.iter().map(|x| x.parameters.len()).min().unwrap_or(0)
194    }
195
196    /// 返回可选参数数量(最多参数数量 - 最少参数数量)
197    fn optional_parameters(&self) -> usize {
198        let required_params = self.required_parameters();
199        let max = self
200            .signatures
201            .iter()
202            .map(|x| x.parameters.len())
203            .max()
204            .unwrap_or(0);
205
206        max - required_params
207    }
208
209    /// 检查参数类型是否匹配任一重载
210    fn check_types(
211        &self,
212        args: &[Rc<VariableType>],
213    ) -> FunctionTypecheck {
214        let mut typecheck = FunctionTypecheck::default();
215        if self.signatures.is_empty() {
216            typecheck.general = Some("No implementation".to_string());
217            return typecheck;
218        }
219
220        let required_params = self.required_parameters();
221        let optional_params = self.optional_parameters();
222        let total_params = required_params + optional_params;
223
224        // 检查参数数量是否在允许范围内
225        if args.len() < required_params || args.len() > total_params {
226            typecheck.general = Some(format!(
227                "Expected `{required_params} - {total_params}` arguments, got `{}`.",
228                args.len()
229            ))
230        }
231
232        // 查找完全匹配的重载
233        for signature in &self.signatures {
234            let all_match = args
235                .iter()
236                .zip(signature.parameters.iter())
237                .all(|(arg, param)| arg.satisfies(param));
238            if all_match {
239                typecheck.return_type = signature.return_type.clone();
240                return typecheck;
241            }
242        }
243
244        // 检查每个参数位置的类型错误
245        for (i, arg) in args.iter().enumerate() {
246            let possible_types: Vec<&VariableType> = self
247                .signatures
248                .iter()
249                .filter_map(|sig| sig.parameters.get(i))
250                .collect();
251
252            if !possible_types.iter().any(|param| arg.satisfies(param)) {
253                let type_union = self.param_type_str(i);
254                typecheck.arguments.push((
255                    i,
256                    format!(
257                        "Argument of type `{arg}` is not assignable to parameter of type `{type_union}`.",
258                    ),
259                ))
260            }
261        }
262
263        // 生成可用重载的错误信息
264        let available_signatures = self
265            .signatures
266            .iter()
267            .map(|sig| {
268                let param_list = sig
269                    .parameters
270                    .iter()
271                    .map(|x| x.to_string())
272                    .collect::<Vec<_>>()
273                    .join(", ");
274                format!("`({param_list}) -> {}`", sig.return_type)
275            })
276            .collect::<Vec<_>>()
277            .join("\n");
278        typecheck.general = Some(format!(
279            "No function overload matches provided arguments. Available overloads:\n{available_signatures}"
280        ));
281
282        typecheck
283    }
284
285    /// 执行函数调用
286    fn call(
287        &self,
288        args: Arguments,
289    ) -> anyhow::Result<Variable> {
290        (&self.implementation)(args)
291    }
292
293    /// 获取指定位置的参数类型(所有重载中该位置类型的并集)
294    fn param_type(
295        &self,
296        index: usize,
297    ) -> Option<VariableType> {
298        self.signatures
299            .iter()
300            .filter_map(|sig| sig.parameters.get(index))
301            .cloned()
302            .reduce(|a, b| a.merge(&b))
303    }
304
305    /// 获取指定位置参数类型的字符串表示(包含所有可能的类型)
306    fn param_type_str(
307        &self,
308        index: usize,
309    ) -> String {
310        let possible_types: Vec<String> = self
311            .signatures
312            .iter()
313            .filter_map(|sig| sig.parameters.get(index))
314            .map(|x| x.to_string())
315            .collect();
316        if possible_types.is_empty() {
317            return String::from("never");
318        }
319
320        let is_optional = possible_types.len() != self.signatures.len();
321        let possible_types: Vec<String> = possible_types
322            .into_iter()
323            .collect::<HashSet<_>>()
324            .into_iter()
325            .collect();
326
327        let type_union = possible_types.join(" | ");
328        if is_optional {
329            return format!("Optional<{type_union}>");
330        }
331
332        type_union
333    }
334
335    /// 获取返回类型(所有重载返回类型的并集)
336    fn return_type(&self) -> VariableType {
337        self.signatures
338            .iter()
339            .map(|sig| &sig.return_type)
340            .cloned()
341            .reduce(|a, b| a.merge(&b))
342            .unwrap_or(VariableType::Null)
343    }
344
345    /// 获取返回类型的字符串表示(包含所有可能的返回类型)
346    fn return_type_str(&self) -> String {
347        let possible_types: Vec<String> = self
348            .signatures
349            .iter()
350            .map(|sig| sig.return_type.clone())
351            .map(|x| x.to_string())
352            .collect();
353        if possible_types.is_empty() {
354            return String::from("never");
355        }
356
357        possible_types
358            .into_iter()
359            .collect::<HashSet<_>>()
360            .into_iter()
361            .collect::<Vec<_>>()
362            .join(" | ")
363    }
364}