// moduforge_rules_expression/functions/custom.rs

//! 自定义函数模块
//!
//! 支持在运行时动态注册自定义函数，并可以访问 State。

5use crate::functions::defs::{
6    FunctionDefinition, FunctionSignature, StaticFunction,
7};
8use crate::functions::arguments::Arguments;
9use crate::variable::{Variable, VariableType};
10use moduforge_state::State;
11use std::rc::Rc;
12use std::sync::Arc;
13use std::collections::HashMap;
14use std::cell::RefCell;
15use std::fmt::Display;
16use anyhow::Result as AnyhowResult;
17
/// Identifier of a user-registered custom function, referenced by its name.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct CustomFunction {
    /// Function name, as used as the key in `CustomFunctionRegistry`.
    pub name: String,
}

impl CustomFunction {
    /// Creates an identifier for the function called `name`.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...). No check is
    /// made that the function is actually registered; use `CustomFunction::try_from` when
    /// that validation is wanted.
    pub fn new(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }
}

impl Display for CustomFunction {
    /// Formats the identifier as its bare function name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.name)
    }
}
39
40impl TryFrom<&str> for CustomFunction {
41    type Error = strum::ParseError;
42
43    fn try_from(value: &str) -> Result<Self, Self::Error> {
44        // 检查是否为已注册的自定义函数
45        if CustomFunctionRegistry::is_registered(value) {
46            Ok(CustomFunction::new(value.to_string()))
47        } else {
48            Err(strum::ParseError::VariantNotFound)
49        }
50    }
51}
52
53/// 自定义函数的执行器类型
54pub type CustomFunctionExecutor = Box<
55    dyn Fn(&Arguments, Option<&Arc<State>>) -> AnyhowResult<Variable> + 'static,
56>;
57
58/// 自定义函数定义
59pub struct CustomFunctionDefinition {
60    /// 函数名称
61    pub name: String,
62    /// 函数签名
63    pub signature: FunctionSignature,
64    /// 执行器
65    pub executor: CustomFunctionExecutor,
66}
67
68impl CustomFunctionDefinition {
69    pub fn new(
70        name: String,
71        signature: FunctionSignature,
72        executor: CustomFunctionExecutor,
73    ) -> Self {
74        Self { name, signature, executor }
75    }
76}
77
78impl FunctionDefinition for CustomFunctionDefinition {
79    fn call(
80        &self,
81        args: Arguments,
82    ) -> AnyhowResult<Variable> {
83        // 尝试获取State上下文(如果可用)
84        let state = CURRENT_STATE.with(|s| s.borrow().clone());
85        (self.executor)(&args, state.as_ref())
86    }
87
88    fn required_parameters(&self) -> usize {
89        self.signature.parameters.len()
90    }
91
92    fn optional_parameters(&self) -> usize {
93        0 // 暂时不支持可选参数
94    }
95
96    fn check_types(
97        &self,
98        args: &[Rc<VariableType>],
99    ) -> crate::functions::defs::FunctionTypecheck {
100        let mut typecheck =
101            crate::functions::defs::FunctionTypecheck::default();
102        typecheck.return_type = self.signature.return_type.clone();
103
104        if args.len() != self.required_parameters() {
105            typecheck.general = Some(format!(
106                "期望 `{}` 参数, 实际 `{}` 参数.",
107                self.required_parameters(),
108                args.len()
109            ));
110        }
111
112        // 检查每个参数类型
113        for (i, (arg, expected_type)) in
114            args.iter().zip(self.signature.parameters.iter()).enumerate()
115        {
116            if !arg.satisfies(expected_type) {
117                typecheck.arguments.push((
118                    i,
119                    format!(
120                        "参数类型 `{arg}` 不能赋值给参数类型 `{expected_type}`.",
121                    ),
122                ));
123            }
124        }
125
126        typecheck
127    }
128
129    fn param_type(
130        &self,
131        index: usize,
132    ) -> Option<VariableType> {
133        self.signature.parameters.get(index).cloned()
134    }
135
136    fn param_type_str(
137        &self,
138        index: usize,
139    ) -> String {
140        self.signature
141            .parameters
142            .get(index)
143            .map(|x| x.to_string())
144            .unwrap_or_else(|| "never".to_string())
145    }
146
147    fn return_type(&self) -> VariableType {
148        self.signature.return_type.clone()
149    }
150
151    fn return_type_str(&self) -> String {
152        self.signature.return_type.to_string()
153    }
154}
155
156thread_local! {
157    /// 当前State上下文(用于自定义函数访问)
158    static CURRENT_STATE: RefCell<Option<Arc<State>>> = RefCell::new(None);
159}
160
161/// 自定义函数注册表
162pub struct CustomFunctionRegistry {
163    functions: HashMap<String, Rc<CustomFunctionDefinition>>,
164}
165
166impl CustomFunctionRegistry {
167    thread_local!(
168        static INSTANCE: RefCell<CustomFunctionRegistry> = RefCell::new(CustomFunctionRegistry::new())
169    );
170
171    fn new() -> Self {
172        Self { functions: HashMap::new() }
173    }
174
175    /// 注册自定义函数
176    pub fn register_function(
177        name: String,
178        signature: FunctionSignature,
179        executor: CustomFunctionExecutor,
180    ) -> Result<(), String> {
181        Self::INSTANCE.with(|registry| {
182            let mut reg = registry.borrow_mut();
183            if reg.functions.contains_key(&name) {
184                return Err(format!("函数 '{}' 已经存在", name));
185            }
186
187            let definition = CustomFunctionDefinition::new(
188                name.clone(),
189                signature,
190                executor,
191            );
192            reg.functions.insert(name, Rc::new(definition));
193            Ok(())
194        })
195    }
196
197    /// 获取函数定义
198    pub fn get_definition(name: &str) -> Option<Rc<dyn FunctionDefinition>> {
199        Self::INSTANCE.with(|registry| {
200            registry
201                .borrow()
202                .functions
203                .get(name)
204                .map(|def| def.clone() as Rc<dyn FunctionDefinition>)
205        })
206    }
207
208    /// 检查函数是否已注册
209    pub fn is_registered(name: &str) -> bool {
210        Self::INSTANCE
211            .with(|registry| registry.borrow().functions.contains_key(name))
212    }
213
214    /// 设置当前State上下文
215    pub fn set_current_state(state: Option<Arc<State>>) {
216        CURRENT_STATE.with(|s| {
217            *s.borrow_mut() = state;
218        });
219    }
220
221    /// 检查当前是否有活跃的State
222    pub fn has_current_state() -> bool {
223        CURRENT_STATE.with(|s| s.borrow().is_some())
224    }
225
226    /// 列出所有已注册的函数
227    pub fn list_functions() -> Vec<String> {
228        Self::INSTANCE.with(|registry| {
229            registry.borrow().functions.keys().cloned().collect()
230        })
231    }
232
233    /// 清空所有注册的函数
234    pub fn clear() {
235        Self::INSTANCE.with(|registry| {
236            registry.borrow_mut().functions.clear();
237        });
238    }
239}
240
241impl From<&CustomFunction> for Rc<dyn FunctionDefinition> {
242    fn from(custom: &CustomFunction) -> Self {
243        CustomFunctionRegistry::get_definition(&custom.name).unwrap_or_else(
244            || {
245                // 如果函数不存在,返回一个错误函数
246                Rc::new(StaticFunction {
247                    signature: FunctionSignature {
248                        parameters: vec![],
249                        return_type: VariableType::Null,
250                    },
251                    implementation: Rc::new(|_| {
252                        Err(anyhow::anyhow!("自定义函数未找到"))
253                    }),
254                })
255            },
256        )
257    }
258}