mf_expression/functions/
custom.rs

1//! 自定义函数模块
2//!
3//! 支持在运行时动态注册自定义函数,并可以访问State
4
5use crate::functions::defs::{
6    FunctionDefinition, FunctionSignature, StaticFunction,
7};
8use crate::functions::arguments::Arguments;
9use crate::variable::{Variable, VariableType};
10use std::rc::Rc;
11use std::sync::Arc;
12use std::collections::HashMap;
13use std::cell::RefCell;
14use std::fmt::Display;
15use anyhow::Result as AnyhowResult;
16use std::any::Any;
17use std::marker::PhantomData;
18
/// Identifier of a custom function; it carries nothing but the function name.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct CustomFunction {
    /// Name of the function.
    pub name: String,
}

impl CustomFunction {
    /// Wraps `name` in a `CustomFunction`.
    ///
    /// No registry lookup happens here; the value is built unconditionally.
    pub fn new(name: String) -> Self {
        CustomFunction { name }
    }
}
31
32impl Display for CustomFunction {
33    fn fmt(
34        &self,
35        f: &mut std::fmt::Formatter<'_>,
36    ) -> std::fmt::Result {
37        write!(f, "{}", self.name)
38    }
39}
40
41impl TryFrom<&str> for CustomFunction {
42    type Error = strum::ParseError;
43
44    fn try_from(value: &str) -> Result<Self, Self::Error> {
45        // 检查是否为已注册的自定义函数
46        if CustomFunctionRegistry::is_registered(value) {
47            Ok(CustomFunction::new(value.to_string()))
48        } else {
49            Err(strum::ParseError::VariantNotFound)
50        }
51    }
52}
53
54/// 自定义函数的内部执行器类型 (类型擦除)
55type ErasedExecutor = Box<
56    dyn Fn(
57            &Arguments,
58            Option<&Arc<dyn Any + Send + Sync>>,
59        ) -> AnyhowResult<Variable>
60        + 'static,
61>;
62
63/// 自定义函数定义
64pub struct CustomFunctionDefinition {
65    /// 函数名称
66    pub name: String,
67    /// 函数签名
68    pub signature: FunctionSignature,
69    /// 执行器
70    pub executor: ErasedExecutor,
71}
72
73impl CustomFunctionDefinition {
74    pub fn new(
75        name: String,
76        signature: FunctionSignature,
77        executor: ErasedExecutor,
78    ) -> Self {
79        Self { name, signature, executor }
80    }
81}
82
83impl FunctionDefinition for CustomFunctionDefinition {
84    fn call(
85        &self,
86        args: Arguments,
87    ) -> AnyhowResult<Variable> {
88        // 尝试获取State上下文(如果可用)
89        let state = CURRENT_STATE.with(|s| s.borrow().clone());
90        (self.executor)(&args, state.as_ref())
91    }
92
93    fn required_parameters(&self) -> usize {
94        self.signature.parameters.len()
95    }
96
97    fn optional_parameters(&self) -> usize {
98        0 // 暂时不支持可选参数
99    }
100
101    fn check_types(
102        &self,
103        args: &[Rc<VariableType>],
104    ) -> crate::functions::defs::FunctionTypecheck {
105        let mut typecheck =
106            crate::functions::defs::FunctionTypecheck::default();
107        typecheck.return_type = self.signature.return_type.clone();
108
109        if args.len() != self.required_parameters() {
110            typecheck.general = Some(format!(
111                "期望 `{}` 参数, 实际 `{}` 参数.",
112                self.required_parameters(),
113                args.len()
114            ));
115        }
116
117        // 检查每个参数类型
118        for (i, (arg, expected_type)) in
119            args.iter().zip(self.signature.parameters.iter()).enumerate()
120        {
121            if !arg.satisfies(expected_type) {
122                typecheck.arguments.push((
123                    i,
124                    format!(
125                        "参数类型 `{arg}` 不能赋值给参数类型 `{expected_type}`.",
126                    ),
127                ));
128            }
129        }
130
131        typecheck
132    }
133
134    fn param_type(
135        &self,
136        index: usize,
137    ) -> Option<VariableType> {
138        self.signature.parameters.get(index).cloned()
139    }
140
141    fn param_type_str(
142        &self,
143        index: usize,
144    ) -> String {
145        self.signature
146            .parameters
147            .get(index)
148            .map(|x| x.to_string())
149            .unwrap_or_else(|| "never".to_string())
150    }
151
152    fn return_type(&self) -> VariableType {
153        self.signature.return_type.clone()
154    }
155
156    fn return_type_str(&self) -> String {
157        self.signature.return_type.to_string()
158    }
159}
160
thread_local! {
    /// Current state context of this thread, read by custom functions when
    /// they are called; `None` when no state has been set.
    static CURRENT_STATE: RefCell<Option<Arc<dyn Any + Send + Sync>>> =
        RefCell::new(None);
}
165
166/// 自定义函数注册表
167pub struct CustomFunctionRegistry {
168    functions: HashMap<String, Rc<CustomFunctionDefinition>>,
169}
170
171impl CustomFunctionRegistry {
172    thread_local!(
173        static INSTANCE: RefCell<CustomFunctionRegistry> = RefCell::new(CustomFunctionRegistry::new())
174    );
175
176    fn new() -> Self {
177        Self { functions: HashMap::new() }
178    }
179
180    /// 注册自定义函数 (内部使用)
181    fn register_function_erased(
182        name: String,
183        signature: FunctionSignature,
184        executor: ErasedExecutor,
185    ) -> Result<(), String> {
186        Self::INSTANCE.with(|registry| {
187            let mut reg = registry.borrow_mut();
188            if reg.functions.contains_key(&name) {
189                return Err(format!("函数 '{}' 已经存在", name));
190            }
191
192            let definition = CustomFunctionDefinition::new(
193                name.clone(),
194                signature,
195                executor,
196            );
197            reg.functions.insert(name, Rc::new(definition));
198            Ok(())
199        })
200    }
201
202    /// 获取函数定义
203    pub fn get_definition(name: &str) -> Option<Rc<dyn FunctionDefinition>> {
204        Self::INSTANCE.with(|registry| {
205            registry
206                .borrow()
207                .functions
208                .get(name)
209                .map(|def| def.clone() as Rc<dyn FunctionDefinition>)
210        })
211    }
212
213    /// 检查函数是否已注册
214    pub fn is_registered(name: &str) -> bool {
215        Self::INSTANCE
216            .with(|registry| registry.borrow().functions.contains_key(name))
217    }
218
219    /// 设置当前State上下文
220    pub fn set_current_state<S: Send + Sync + 'static>(state: Option<Arc<S>>) {
221        CURRENT_STATE.with(|s| {
222            *s.borrow_mut() = state.map(|st| st as Arc<dyn Any + Send + Sync>);
223        });
224    }
225
226    /// 检查当前是否有活跃的State
227    pub fn has_current_state() -> bool {
228        CURRENT_STATE.with(|s| s.borrow().is_some())
229    }
230
231    /// 清理当前State上下文
232    pub fn clear_current_state() {
233        CURRENT_STATE.with(|s| {
234            *s.borrow_mut() = None;
235        });
236    }
237
238    /// 列出所有已注册的函数
239    pub fn list_functions() -> Vec<String> {
240        Self::INSTANCE.with(|registry| {
241            registry.borrow().functions.keys().cloned().collect()
242        })
243    }
244
245    /// 清空所有注册的函数
246    pub fn clear() {
247        Self::INSTANCE.with(|registry| {
248            registry.borrow_mut().functions.clear();
249        });
250    }
251}
252
253/// 用于注册特定状态类型 `S` 的函数的辅助结构。
254pub struct CustomFunctionHelper<S> {
255    _marker: PhantomData<S>,
256}
257
258impl<S: Send + Sync + 'static> CustomFunctionHelper<S> {
259    /// 创建一个新的辅助实例。
260    pub fn new() -> Self {
261        Self { _marker: PhantomData }
262    }
263
264    /// 注册一个自定义函数。
265    ///
266    /// # Parameters
267    /// - `name`: 函数名。
268    /// - `params`: 函数参数类型列表。
269    /// - `return_type`: 函数返回类型。
270    /// - `executor`: 函数的实现,它接收参数和一个可选的 `Arc<S>` 状态引用。
271    pub fn register_function(
272        &self,
273        name: String,
274        params: Vec<VariableType>,
275        return_type: VariableType,
276        executor: Box<
277            dyn Fn(&Arguments, Option<&S>) -> AnyhowResult<Variable> + 'static,
278        >,
279    ) -> Result<(), String> {
280        let signature = FunctionSignature { parameters: params, return_type };
281
282        let wrapped_executor: ErasedExecutor =
283            Box::new(move |args, state_any| {
284                let typed_state = state_any.and_then(|s| s.downcast_ref::<S>());
285                executor(args, typed_state)
286            });
287
288        CustomFunctionRegistry::register_function_erased(
289            name,
290            signature,
291            wrapped_executor,
292        )
293    }
294}
295
296impl<S: Send + Sync + 'static> Default for CustomFunctionHelper<S> {
297    fn default() -> Self {
298        Self::new()
299    }
300}
301
302impl From<&CustomFunction> for Rc<dyn FunctionDefinition> {
303    fn from(custom: &CustomFunction) -> Self {
304        CustomFunctionRegistry::get_definition(&custom.name).unwrap_or_else(
305            || {
306                // 如果函数不存在,返回一个错误函数
307                Rc::new(StaticFunction {
308                    signature: FunctionSignature {
309                        parameters: vec![],
310                        return_type: VariableType::Null,
311                    },
312                    implementation: Rc::new(|_| {
313                        Err(anyhow::anyhow!("自定义函数未找到"))
314                    }),
315                })
316            },
317        )
318    }
319}