moduforge_rules_expression/functions/
custom.rs

1//! 自定义函数模块
2//!
3//! 支持在运行时动态注册自定义函数,并可以访问State
4
5use crate::functions::defs::{
6    FunctionDefinition, FunctionSignature, StaticFunction,
7};
8use crate::functions::arguments::Arguments;
9use crate::variable::{Variable, VariableType};
10use std::rc::Rc;
11use std::sync::Arc;
12use std::collections::HashMap;
13use std::cell::RefCell;
14use std::fmt::Display;
15use anyhow::Result as AnyhowResult;
16use std::any::Any;
17use std::marker::PhantomData;
18
/// Identifier of a custom function, referenced by the name it was registered under.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CustomFunction {
    /// Registered function name.
    pub name: String,
}
25
26impl CustomFunction {
27    pub fn new(name: String) -> Self {
28        Self { name }
29    }
30}
31
32impl Display for CustomFunction {
33    fn fmt(
34        &self,
35        f: &mut std::fmt::Formatter<'_>,
36    ) -> std::fmt::Result {
37        write!(f, "{}", self.name)
38    }
39}
40
41impl TryFrom<&str> for CustomFunction {
42    type Error = strum::ParseError;
43
44    fn try_from(value: &str) -> Result<Self, Self::Error> {
45        // 检查是否为已注册的自定义函数
46        if CustomFunctionRegistry::is_registered(value) {
47            Ok(CustomFunction::new(value.to_string()))
48        } else {
49            Err(strum::ParseError::VariantNotFound)
50        }
51    }
52}
53
54/// 自定义函数的内部执行器类型 (类型擦除)
55type ErasedExecutor = Box<dyn Fn(&Arguments, Option<&Arc<dyn Any + Send + Sync>>) -> AnyhowResult<Variable> + 'static>;
56
57/// 自定义函数定义
58pub struct CustomFunctionDefinition {
59    /// 函数名称
60    pub name: String,
61    /// 函数签名
62    pub signature: FunctionSignature,
63    /// 执行器
64    pub executor: ErasedExecutor,
65}
66
67impl CustomFunctionDefinition {
68    pub fn new(
69        name: String,
70        signature: FunctionSignature,
71        executor: ErasedExecutor,
72    ) -> Self {
73        Self { name, signature, executor }
74    }
75}
76
77impl FunctionDefinition for CustomFunctionDefinition {
78    fn call(
79        &self,
80        args: Arguments,
81    ) -> AnyhowResult<Variable> {
82        // 尝试获取State上下文(如果可用)
83        let state = CURRENT_STATE.with(|s| s.borrow().clone());
84        (self.executor)(&args, state.as_ref())
85    }
86
87    fn required_parameters(&self) -> usize {
88        self.signature.parameters.len()
89    }
90
91    fn optional_parameters(&self) -> usize {
92        0 // 暂时不支持可选参数
93    }
94
95    fn check_types(
96        &self,
97        args: &[Rc<VariableType>],
98    ) -> crate::functions::defs::FunctionTypecheck {
99        let mut typecheck =
100            crate::functions::defs::FunctionTypecheck::default();
101        typecheck.return_type = self.signature.return_type.clone();
102
103        if args.len() != self.required_parameters() {
104            typecheck.general = Some(format!(
105                "期望 `{}` 参数, 实际 `{}` 参数.",
106                self.required_parameters(),
107                args.len()
108            ));
109        }
110
111        // 检查每个参数类型
112        for (i, (arg, expected_type)) in
113            args.iter().zip(self.signature.parameters.iter()).enumerate()
114        {
115            if !arg.satisfies(expected_type) {
116                typecheck.arguments.push((
117                    i,
118                    format!(
119                        "参数类型 `{arg}` 不能赋值给参数类型 `{expected_type}`.",
120                    ),
121                ));
122            }
123        }
124
125        typecheck
126    }
127
128    fn param_type(
129        &self,
130        index: usize,
131    ) -> Option<VariableType> {
132        self.signature.parameters.get(index).cloned()
133    }
134
135    fn param_type_str(
136        &self,
137        index: usize,
138    ) -> String {
139        self.signature
140            .parameters
141            .get(index)
142            .map(|x| x.to_string())
143            .unwrap_or_else(|| "never".to_string())
144    }
145
146    fn return_type(&self) -> VariableType {
147        self.signature.return_type.clone()
148    }
149
150    fn return_type_str(&self) -> String {
151        self.signature.return_type.to_string()
152    }
153}
154
thread_local!(
    /// State visible to custom functions invoked on this thread, stored type-erased.
    static CURRENT_STATE: RefCell<Option<Arc<dyn Any + Send + Sync>>> = RefCell::new(None)
);
159
160/// 自定义函数注册表
161pub struct CustomFunctionRegistry {
162    functions: HashMap<String, Rc<CustomFunctionDefinition>>,
163}
164
165impl CustomFunctionRegistry {
166    thread_local!(
167        static INSTANCE: RefCell<CustomFunctionRegistry> = RefCell::new(CustomFunctionRegistry::new())
168    );
169
170    fn new() -> Self {
171        Self { functions: HashMap::new() }
172    }
173
174    /// 注册自定义函数 (内部使用)
175    fn register_function_erased(
176        name: String,
177        signature: FunctionSignature,
178        executor: ErasedExecutor,
179    ) -> Result<(), String> {
180        Self::INSTANCE.with(|registry| {
181            let mut reg = registry.borrow_mut();
182            if reg.functions.contains_key(&name) {
183                return Err(format!("函数 '{}' 已经存在", name));
184            }
185
186            let definition = CustomFunctionDefinition::new(
187                name.clone(),
188                signature,
189                executor,
190            );
191            reg.functions.insert(name, Rc::new(definition));
192            Ok(())
193        })
194    }
195
196    /// 获取函数定义
197    pub fn get_definition(name: &str) -> Option<Rc<dyn FunctionDefinition>> {
198        Self::INSTANCE.with(|registry| {
199            registry
200                .borrow()
201                .functions
202                .get(name)
203                .map(|def| def.clone() as Rc<dyn FunctionDefinition>)
204        })
205    }
206
207    /// 检查函数是否已注册
208    pub fn is_registered(name: &str) -> bool {
209        Self::INSTANCE
210            .with(|registry| registry.borrow().functions.contains_key(name))
211    }
212
213    /// 设置当前State上下文
214    pub fn set_current_state<S: Send + Sync + 'static>(state: Option<Arc<S>>) {
215        CURRENT_STATE.with(|s| {
216            *s.borrow_mut() = state.map(|st| st as Arc<dyn Any + Send + Sync>);
217        });
218    }
219
220    /// 检查当前是否有活跃的State
221    pub fn has_current_state() -> bool {
222        CURRENT_STATE.with(|s| s.borrow().is_some())
223    }
224
225    /// 清理当前State上下文
226    pub fn clear_current_state() {
227        CURRENT_STATE.with(|s| {
228            *s.borrow_mut() = None;
229        });
230    }
231
232    /// 列出所有已注册的函数
233    pub fn list_functions() -> Vec<String> {
234        Self::INSTANCE.with(|registry| {
235            registry.borrow().functions.keys().cloned().collect()
236        })
237    }
238
239    /// 清空所有注册的函数
240    pub fn clear() {
241        Self::INSTANCE.with(|registry| {
242            registry.borrow_mut().functions.clear();
243        });
244    }
245}
246
/// Zero-sized helper for registering custom functions whose executors expect
/// state of type `S`.
pub struct CustomFunctionHelper<S> {
    _marker: PhantomData<S>,
}
251
252impl<S: Send + Sync + 'static> CustomFunctionHelper<S> {
253    /// 创建一个新的辅助实例。
254    pub fn new() -> Self {
255        Self { _marker: PhantomData }
256    }
257
258    /// 注册一个自定义函数。
259    ///
260    /// # Parameters
261    /// - `name`: 函数名。
262    /// - `params`: 函数参数类型列表。
263    /// - `return_type`: 函数返回类型。
264    /// - `executor`: 函数的实现,它接收参数和一个可选的 `Arc<S>` 状态引用。
265    pub fn register_function(
266        &self,
267        name: String,
268        params: Vec<VariableType>,
269        return_type: VariableType,
270        executor: Box<dyn Fn(&Arguments, Option<&S>) -> AnyhowResult<Variable> + 'static>,
271    ) -> Result<(), String> {
272        let signature = FunctionSignature {
273            parameters: params,
274            return_type,
275        };
276
277        let wrapped_executor: ErasedExecutor = Box::new(move |args, state_any| {
278            let typed_state = state_any.and_then(|s| s.downcast_ref::<S>());
279            executor(args, typed_state)
280        });
281
282        CustomFunctionRegistry::register_function_erased(name, signature, wrapped_executor)
283    }
284}
285
286impl<S: Send + Sync + 'static> Default for CustomFunctionHelper<S> {
287    fn default() -> Self {
288        Self::new()
289    }
290}
291
292impl From<&CustomFunction> for Rc<dyn FunctionDefinition> {
293    fn from(custom: &CustomFunction) -> Self {
294        CustomFunctionRegistry::get_definition(&custom.name).unwrap_or_else(
295            || {
296                // 如果函数不存在,返回一个错误函数
297                Rc::new(StaticFunction {
298                    signature: FunctionSignature {
299                        parameters: vec![],
300                        return_type: VariableType::Null,
301                    },
302                    implementation: Rc::new(|_| {
303                        Err(anyhow::anyhow!("自定义函数未找到"))
304                    }),
305                })
306            },
307        )
308    }
309}