mf_expression/functions/
mf_function.rs

1//! 自定义函数模块
2//!
3//! 支持在运行时动态注册自定义函数,并可以访问State
4
5use crate::functions::defs::{
6    FunctionDefinition, FunctionSignature, StaticFunction,
7};
8use crate::functions::arguments::Arguments;
9use crate::variable::{Variable, VariableType};
10use std::rc::Rc;
11use std::sync::Arc;
12use std::collections::HashMap;
13use std::cell::RefCell;
14use std::fmt::Display;
15use anyhow::Result as AnyhowResult;
16use std::any::Any;
17use std::marker::PhantomData;
18
/// Identifier of a user-registered (custom) expression function.
///
/// Only the function's name is carried here; the implementation itself is
/// resolved separately by name.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct MfFunction {
    /// Name under which the custom function was registered.
    pub name: String,
}
25
26impl MfFunction {
27    pub fn new(name: String) -> Self {
28        Self { name }
29    }
30}
31
32impl Display for MfFunction {
33    fn fmt(
34        &self,
35        f: &mut std::fmt::Formatter<'_>,
36    ) -> std::fmt::Result {
37        write!(f, "{}", self.name)
38    }
39}
40
41impl TryFrom<&str> for MfFunction {
42    type Error = strum::ParseError;
43
44    fn try_from(value: &str) -> Result<Self, Self::Error> {
45        // 检查是否为已注册的自定义函数
46        if MfFunctionRegistry::is_registered(value) {
47            Ok(MfFunction::new(value.to_string()))
48        } else {
49            Err(strum::ParseError::VariantNotFound)
50        }
51    }
52}
53
54/// 自定义函数的内部执行器类型 (类型擦除)
55type ErasedExecutor = Box<
56    dyn Fn(
57            &Arguments,
58            Option<&Arc<dyn Any + Send + Sync>>,
59        ) -> AnyhowResult<Variable>
60        + 'static,
61>;
62
63/// 自定义函数定义
64pub struct MfFunctionDefinition {
65    /// 函数名称
66    pub name: String,
67    /// 函数签名
68    pub signature: FunctionSignature,
69    /// 执行器
70    pub executor: ErasedExecutor,
71}
72
73impl MfFunctionDefinition {
74    pub fn new(
75        name: String,
76        signature: FunctionSignature,
77        executor: ErasedExecutor,
78    ) -> Self {
79        Self { name, signature, executor }
80    }
81}
82
83impl FunctionDefinition for MfFunctionDefinition {
84    fn call(
85        &self,
86        args: Arguments,
87    ) -> AnyhowResult<Variable> {
88        // 尝试获取State上下文(如果可用)
89        let state = CURRENT_STATE.with(|s| s.borrow().clone());
90        (self.executor)(&args, state.as_ref())
91    }
92
93    fn required_parameters(&self) -> usize {
94        self.signature.parameters.len()
95    }
96
97    fn optional_parameters(&self) -> usize {
98        0 // 暂时不支持可选参数
99    }
100
101    fn check_types(
102        &self,
103        args: &[VariableType],
104    ) -> crate::functions::defs::FunctionTypecheck {
105        let mut typecheck =
106            crate::functions::defs::FunctionTypecheck::default();
107        typecheck.return_type = self.signature.return_type.clone();
108
109        if args.len() != self.required_parameters() {
110            typecheck.general = Some(format!(
111                "期望 `{}` 参数, 实际 `{}` 参数.",
112                self.required_parameters(),
113                args.len()
114            ));
115        }
116
117        // 检查每个参数类型
118        for (i, (arg, expected_type)) in
119            args.iter().zip(self.signature.parameters.iter()).enumerate()
120        {
121            if !arg.satisfies(expected_type) {
122                typecheck.arguments.push((
123                    i,
124                    format!(
125                        "参数类型 `{arg}` 不能赋值给参数类型 `{expected_type}`.",
126                    ),
127                ));
128            }
129        }
130
131        typecheck
132    }
133
134    fn param_type(
135        &self,
136        index: usize,
137    ) -> Option<VariableType> {
138        self.signature.parameters.get(index).cloned()
139    }
140
141    fn param_type_str(
142        &self,
143        index: usize,
144    ) -> String {
145        self.signature
146            .parameters
147            .get(index)
148            .map(|x| x.to_string())
149            .unwrap_or_else(|| "never".to_string())
150    }
151
152    fn return_type(&self) -> VariableType {
153        self.signature.return_type.clone()
154    }
155
156    fn return_type_str(&self) -> String {
157        self.signature.return_type.to_string()
158    }
159}
160
thread_local!(
    /// Per-thread slot for the state object that custom-function executors may
    /// read during a call; `None` means no state is installed.
    static CURRENT_STATE: RefCell<Option<Arc<dyn Any + Send + Sync>>> =
        RefCell::new(None);
);
165
166/// 自定义函数注册表
167pub struct MfFunctionRegistry {
168    functions: HashMap<String, Rc<MfFunctionDefinition>>,
169}
170
171impl MfFunctionRegistry {
172    thread_local!(
173        static INSTANCE: RefCell<MfFunctionRegistry> = RefCell::new(MfFunctionRegistry::new())
174    );
175
176    fn new() -> Self {
177        Self { functions: HashMap::new() }
178    }
179
180    /// 注册自定义函数 (内部使用)
181    fn register_function_erased(
182        name: String,
183        signature: FunctionSignature,
184        executor: ErasedExecutor,
185    ) -> Result<(), String> {
186        Self::INSTANCE.with(|registry| {
187            let mut reg = registry.borrow_mut();
188            if reg.functions.contains_key(&name) {
189                return Err(format!("函数 '{}' 已经存在", name));
190            }
191
192            let definition =
193                MfFunctionDefinition::new(name.clone(), signature, executor);
194            reg.functions.insert(name, Rc::new(definition));
195            Ok(())
196        })
197    }
198
199    /// 获取函数定义
200    pub fn get_definition(name: &str) -> Option<Rc<dyn FunctionDefinition>> {
201        Self::INSTANCE.with(|registry| {
202            registry
203                .borrow()
204                .functions
205                .get(name)
206                .map(|def| def.clone() as Rc<dyn FunctionDefinition>)
207        })
208    }
209
210    /// 检查函数是否已注册
211    pub fn is_registered(name: &str) -> bool {
212        Self::INSTANCE
213            .with(|registry| registry.borrow().functions.contains_key(name))
214    }
215
216    /// 设置当前State上下文
217    pub fn set_current_state<S: Send + Sync + 'static>(state: Option<Arc<S>>) {
218        CURRENT_STATE.with(|s| {
219            *s.borrow_mut() = state.map(|st| st as Arc<dyn Any + Send + Sync>);
220        });
221    }
222
223    /// 检查当前是否有活跃的State
224    pub fn has_current_state() -> bool {
225        CURRENT_STATE.with(|s| s.borrow().is_some())
226    }
227
228    /// 清理当前State上下文
229    pub fn clear_current_state() {
230        CURRENT_STATE.with(|s| {
231            *s.borrow_mut() = None;
232        });
233    }
234
235    /// 列出所有已注册的函数
236    pub fn list_functions() -> Vec<String> {
237        Self::INSTANCE.with(|registry| {
238            registry.borrow().functions.keys().cloned().collect()
239        })
240    }
241
242    /// 清空所有注册的函数
243    pub fn clear() {
244        Self::INSTANCE.with(|registry| {
245            registry.borrow_mut().functions.clear();
246        });
247    }
248}
249
/// Registration front-end for custom functions whose executors expect state of
/// type `S`.
///
/// The helper carries no data at runtime; `S` exists only at the type level.
pub struct MfFunctionHelper<S> {
    _marker: PhantomData<S>,
}
254
255impl<S: Send + Sync + 'static> MfFunctionHelper<S> {
256    /// 创建一个新的辅助实例。
257    pub fn new() -> Self {
258        Self { _marker: PhantomData }
259    }
260
261    /// 注册一个自定义函数。
262    ///
263    /// # Parameters
264    /// - `name`: 函数名。
265    /// - `params`: 函数参数类型列表。
266    /// - `return_type`: 函数返回类型。
267    /// - `executor`: 函数的实现,它接收参数和一个可选的 `Arc<S>` 状态引用。
268    pub fn register_function(
269        &self,
270        name: String,
271        params: Vec<VariableType>,
272        return_type: VariableType,
273        executor: Box<
274            dyn Fn(&Arguments, Option<&S>) -> AnyhowResult<Variable> + 'static,
275        >,
276    ) -> Result<(), String> {
277        let signature = FunctionSignature { parameters: params, return_type };
278
279        let wrapped_executor: ErasedExecutor =
280            Box::new(move |args, state_any| {
281                let typed_state = state_any.and_then(|s| s.downcast_ref::<S>());
282                executor(args, typed_state)
283            });
284
285        MfFunctionRegistry::register_function_erased(
286            name,
287            signature,
288            wrapped_executor,
289        )
290    }
291}
292
293impl<S: Send + Sync + 'static> Default for MfFunctionHelper<S> {
294    fn default() -> Self {
295        Self::new()
296    }
297}
298
299impl From<&MfFunction> for Rc<dyn FunctionDefinition> {
300    fn from(custom: &MfFunction) -> Self {
301        MfFunctionRegistry::get_definition(&custom.name).unwrap_or_else(|| {
302            // 如果函数不存在,返回一个错误函数
303            Rc::new(StaticFunction {
304                signature: FunctionSignature {
305                    parameters: vec![],
306                    return_type: VariableType::Null,
307                },
308                implementation: Rc::new(|_| {
309                    Err(anyhow::anyhow!("自定义函数未找到"))
310                }),
311            })
312        })
313    }
314}