moduforge_rules_expression/functions/
custom.rs1use crate::functions::defs::{
6 FunctionDefinition, FunctionSignature, StaticFunction,
7};
8use crate::functions::arguments::Arguments;
9use crate::variable::{Variable, VariableType};
10use std::rc::Rc;
11use std::sync::Arc;
12use std::collections::HashMap;
13use std::cell::RefCell;
14use std::fmt::Display;
15use anyhow::Result as AnyhowResult;
16use std::any::Any;
17use std::marker::PhantomData;
18
/// Handle to a user-registered custom function, identified purely by its name.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct CustomFunction {
    /// Name under which the function is (expected to be) registered.
    pub name: String,
}

impl CustomFunction {
    /// Builds a handle for the function called `name`.
    ///
    /// This only stores the name; it performs no registry lookup.
    pub fn new(name: String) -> Self {
        CustomFunction { name }
    }
}

impl Display for CustomFunction {
    /// Renders the handle as the bare function name (formatter flags such as
    /// width are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.name)
    }
}
40
41impl TryFrom<&str> for CustomFunction {
42 type Error = strum::ParseError;
43
44 fn try_from(value: &str) -> Result<Self, Self::Error> {
45 if CustomFunctionRegistry::is_registered(value) {
47 Ok(CustomFunction::new(value.to_string()))
48 } else {
49 Err(strum::ParseError::VariantNotFound)
50 }
51 }
52}
53
54type ErasedExecutor = Box<dyn Fn(&Arguments, Option<&Arc<dyn Any + Send + Sync>>) -> AnyhowResult<Variable> + 'static>;
56
57pub struct CustomFunctionDefinition {
59 pub name: String,
61 pub signature: FunctionSignature,
63 pub executor: ErasedExecutor,
65}
66
67impl CustomFunctionDefinition {
68 pub fn new(
69 name: String,
70 signature: FunctionSignature,
71 executor: ErasedExecutor,
72 ) -> Self {
73 Self { name, signature, executor }
74 }
75}
76
77impl FunctionDefinition for CustomFunctionDefinition {
78 fn call(
79 &self,
80 args: Arguments,
81 ) -> AnyhowResult<Variable> {
82 let state = CURRENT_STATE.with(|s| s.borrow().clone());
84 (self.executor)(&args, state.as_ref())
85 }
86
87 fn required_parameters(&self) -> usize {
88 self.signature.parameters.len()
89 }
90
91 fn optional_parameters(&self) -> usize {
92 0 }
94
95 fn check_types(
96 &self,
97 args: &[Rc<VariableType>],
98 ) -> crate::functions::defs::FunctionTypecheck {
99 let mut typecheck =
100 crate::functions::defs::FunctionTypecheck::default();
101 typecheck.return_type = self.signature.return_type.clone();
102
103 if args.len() != self.required_parameters() {
104 typecheck.general = Some(format!(
105 "期望 `{}` 参数, 实际 `{}` 参数.",
106 self.required_parameters(),
107 args.len()
108 ));
109 }
110
111 for (i, (arg, expected_type)) in
113 args.iter().zip(self.signature.parameters.iter()).enumerate()
114 {
115 if !arg.satisfies(expected_type) {
116 typecheck.arguments.push((
117 i,
118 format!(
119 "参数类型 `{arg}` 不能赋值给参数类型 `{expected_type}`.",
120 ),
121 ));
122 }
123 }
124
125 typecheck
126 }
127
128 fn param_type(
129 &self,
130 index: usize,
131 ) -> Option<VariableType> {
132 self.signature.parameters.get(index).cloned()
133 }
134
135 fn param_type_str(
136 &self,
137 index: usize,
138 ) -> String {
139 self.signature
140 .parameters
141 .get(index)
142 .map(|x| x.to_string())
143 .unwrap_or_else(|| "never".to_string())
144 }
145
146 fn return_type(&self) -> VariableType {
147 self.signature.return_type.clone()
148 }
149
150 fn return_type_str(&self) -> String {
151 self.signature.return_type.to_string()
152 }
153}
154
thread_local! {
    /// State handed to custom executors on this thread: written by
    /// `CustomFunctionRegistry::set_current_state` and read by
    /// `CustomFunctionDefinition::call`. `None` means no state is installed.
    ///
    /// The initializer is `const`, so accesses skip the lazy-initialization check.
    static CURRENT_STATE: RefCell<Option<Arc<dyn Any + Send + Sync>>> = const { RefCell::new(None) };
}
159
160pub struct CustomFunctionRegistry {
162 functions: HashMap<String, Rc<CustomFunctionDefinition>>,
163}
164
165impl CustomFunctionRegistry {
166 thread_local!(
167 static INSTANCE: RefCell<CustomFunctionRegistry> = RefCell::new(CustomFunctionRegistry::new())
168 );
169
170 fn new() -> Self {
171 Self { functions: HashMap::new() }
172 }
173
174 fn register_function_erased(
176 name: String,
177 signature: FunctionSignature,
178 executor: ErasedExecutor,
179 ) -> Result<(), String> {
180 Self::INSTANCE.with(|registry| {
181 let mut reg = registry.borrow_mut();
182 if reg.functions.contains_key(&name) {
183 return Err(format!("函数 '{}' 已经存在", name));
184 }
185
186 let definition = CustomFunctionDefinition::new(
187 name.clone(),
188 signature,
189 executor,
190 );
191 reg.functions.insert(name, Rc::new(definition));
192 Ok(())
193 })
194 }
195
196 pub fn get_definition(name: &str) -> Option<Rc<dyn FunctionDefinition>> {
198 Self::INSTANCE.with(|registry| {
199 registry
200 .borrow()
201 .functions
202 .get(name)
203 .map(|def| def.clone() as Rc<dyn FunctionDefinition>)
204 })
205 }
206
207 pub fn is_registered(name: &str) -> bool {
209 Self::INSTANCE
210 .with(|registry| registry.borrow().functions.contains_key(name))
211 }
212
213 pub fn set_current_state<S: Send + Sync + 'static>(state: Option<Arc<S>>) {
215 CURRENT_STATE.with(|s| {
216 *s.borrow_mut() = state.map(|st| st as Arc<dyn Any + Send + Sync>);
217 });
218 }
219
220 pub fn has_current_state() -> bool {
222 CURRENT_STATE.with(|s| s.borrow().is_some())
223 }
224
225 pub fn clear_current_state() {
227 CURRENT_STATE.with(|s| {
228 *s.borrow_mut() = None;
229 });
230 }
231
232 pub fn list_functions() -> Vec<String> {
234 Self::INSTANCE.with(|registry| {
235 registry.borrow().functions.keys().cloned().collect()
236 })
237 }
238
239 pub fn clear() {
241 Self::INSTANCE.with(|registry| {
242 registry.borrow_mut().functions.clear();
243 });
244 }
245}
246
247pub struct CustomFunctionHelper<S> {
249 _marker: PhantomData<S>,
250}
251
252impl<S: Send + Sync + 'static> CustomFunctionHelper<S> {
253 pub fn new() -> Self {
255 Self { _marker: PhantomData }
256 }
257
258 pub fn register_function(
266 &self,
267 name: String,
268 params: Vec<VariableType>,
269 return_type: VariableType,
270 executor: Box<dyn Fn(&Arguments, Option<&S>) -> AnyhowResult<Variable> + 'static>,
271 ) -> Result<(), String> {
272 let signature = FunctionSignature {
273 parameters: params,
274 return_type,
275 };
276
277 let wrapped_executor: ErasedExecutor = Box::new(move |args, state_any| {
278 let typed_state = state_any.and_then(|s| s.downcast_ref::<S>());
279 executor(args, typed_state)
280 });
281
282 CustomFunctionRegistry::register_function_erased(name, signature, wrapped_executor)
283 }
284}
285
286impl<S: Send + Sync + 'static> Default for CustomFunctionHelper<S> {
287 fn default() -> Self {
288 Self::new()
289 }
290}
291
292impl From<&CustomFunction> for Rc<dyn FunctionDefinition> {
293 fn from(custom: &CustomFunction) -> Self {
294 CustomFunctionRegistry::get_definition(&custom.name).unwrap_or_else(
295 || {
296 Rc::new(StaticFunction {
298 signature: FunctionSignature {
299 parameters: vec![],
300 return_type: VariableType::Null,
301 },
302 implementation: Rc::new(|_| {
303 Err(anyhow::anyhow!("自定义函数未找到"))
304 }),
305 })
306 },
307 )
308 }
309}