moduforge_rules_expression/functions/
custom.rs1use crate::functions::defs::{
6 FunctionDefinition, FunctionSignature, StaticFunction,
7};
8use crate::functions::arguments::Arguments;
9use crate::variable::{Variable, VariableType};
10use moduforge_state::State;
11use std::rc::Rc;
12use std::sync::Arc;
13use std::collections::HashMap;
14use std::cell::RefCell;
15use std::fmt::Display;
16use anyhow::Result as AnyhowResult;
17
/// Handle that refers to a custom function by its name.
///
/// The handle carries nothing but the name; it does not own the function's
/// signature or implementation.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct CustomFunction {
    /// Name the function is referred to by.
    pub name: String,
}

impl CustomFunction {
    /// Builds a handle for the function called `name`.
    ///
    /// No validation or lookup is performed by this constructor.
    pub fn new(name: String) -> Self {
        Self { name }
    }
}

impl Display for CustomFunction {
    /// Writes the function name.
    ///
    /// Width, fill, alignment and precision flags given by the caller
    /// (for example `{:>10}`) are applied to the name.
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        // `write!(f, "{}", self.name)` would format the name with a fresh,
        // default spec and therefore ignore the caller's flags; `pad` applies
        // the flags stored in `f` to the string.
        f.pad(&self.name)
    }
}
39
40impl TryFrom<&str> for CustomFunction {
41 type Error = strum::ParseError;
42
43 fn try_from(value: &str) -> Result<Self, Self::Error> {
44 if CustomFunctionRegistry::is_registered(value) {
46 Ok(CustomFunction::new(value.to_string()))
47 } else {
48 Err(strum::ParseError::VariantNotFound)
49 }
50 }
51}
52
53pub type CustomFunctionExecutor = Box<
55 dyn Fn(&Arguments, Option<&Arc<State>>) -> AnyhowResult<Variable> + 'static,
56>;
57
58pub struct CustomFunctionDefinition {
60 pub name: String,
62 pub signature: FunctionSignature,
64 pub executor: CustomFunctionExecutor,
66}
67
68impl CustomFunctionDefinition {
69 pub fn new(
70 name: String,
71 signature: FunctionSignature,
72 executor: CustomFunctionExecutor,
73 ) -> Self {
74 Self { name, signature, executor }
75 }
76}
77
78impl FunctionDefinition for CustomFunctionDefinition {
79 fn call(
80 &self,
81 args: Arguments,
82 ) -> AnyhowResult<Variable> {
83 let state = CURRENT_STATE.with(|s| s.borrow().clone());
85 (self.executor)(&args, state.as_ref())
86 }
87
88 fn required_parameters(&self) -> usize {
89 self.signature.parameters.len()
90 }
91
92 fn optional_parameters(&self) -> usize {
93 0 }
95
96 fn check_types(
97 &self,
98 args: &[Rc<VariableType>],
99 ) -> crate::functions::defs::FunctionTypecheck {
100 let mut typecheck =
101 crate::functions::defs::FunctionTypecheck::default();
102 typecheck.return_type = self.signature.return_type.clone();
103
104 if args.len() != self.required_parameters() {
105 typecheck.general = Some(format!(
106 "期望 `{}` 参数, 实际 `{}` 参数.",
107 self.required_parameters(),
108 args.len()
109 ));
110 }
111
112 for (i, (arg, expected_type)) in
114 args.iter().zip(self.signature.parameters.iter()).enumerate()
115 {
116 if !arg.satisfies(expected_type) {
117 typecheck.arguments.push((
118 i,
119 format!(
120 "参数类型 `{arg}` 不能赋值给参数类型 `{expected_type}`.",
121 ),
122 ));
123 }
124 }
125
126 typecheck
127 }
128
129 fn param_type(
130 &self,
131 index: usize,
132 ) -> Option<VariableType> {
133 self.signature.parameters.get(index).cloned()
134 }
135
136 fn param_type_str(
137 &self,
138 index: usize,
139 ) -> String {
140 self.signature
141 .parameters
142 .get(index)
143 .map(|x| x.to_string())
144 .unwrap_or_else(|| "never".to_string())
145 }
146
147 fn return_type(&self) -> VariableType {
148 self.signature.return_type.clone()
149 }
150
151 fn return_type_str(&self) -> String {
152 self.signature.return_type.to_string()
153 }
154}
155
156thread_local! {
157 static CURRENT_STATE: RefCell<Option<Arc<State>>> = RefCell::new(None);
159}
160
161pub struct CustomFunctionRegistry {
163 functions: HashMap<String, Rc<CustomFunctionDefinition>>,
164}
165
166impl CustomFunctionRegistry {
167 thread_local!(
168 static INSTANCE: RefCell<CustomFunctionRegistry> = RefCell::new(CustomFunctionRegistry::new())
169 );
170
171 fn new() -> Self {
172 Self { functions: HashMap::new() }
173 }
174
175 pub fn register_function(
177 name: String,
178 signature: FunctionSignature,
179 executor: CustomFunctionExecutor,
180 ) -> Result<(), String> {
181 Self::INSTANCE.with(|registry| {
182 let mut reg = registry.borrow_mut();
183 if reg.functions.contains_key(&name) {
184 return Err(format!("函数 '{}' 已经存在", name));
185 }
186
187 let definition = CustomFunctionDefinition::new(
188 name.clone(),
189 signature,
190 executor,
191 );
192 reg.functions.insert(name, Rc::new(definition));
193 Ok(())
194 })
195 }
196
197 pub fn get_definition(name: &str) -> Option<Rc<dyn FunctionDefinition>> {
199 Self::INSTANCE.with(|registry| {
200 registry
201 .borrow()
202 .functions
203 .get(name)
204 .map(|def| def.clone() as Rc<dyn FunctionDefinition>)
205 })
206 }
207
208 pub fn is_registered(name: &str) -> bool {
210 Self::INSTANCE
211 .with(|registry| registry.borrow().functions.contains_key(name))
212 }
213
214 pub fn set_current_state(state: Option<Arc<State>>) {
216 CURRENT_STATE.with(|s| {
217 *s.borrow_mut() = state;
218 });
219 }
220
221 pub fn has_current_state() -> bool {
223 CURRENT_STATE.with(|s| s.borrow().is_some())
224 }
225
226 pub fn list_functions() -> Vec<String> {
228 Self::INSTANCE.with(|registry| {
229 registry.borrow().functions.keys().cloned().collect()
230 })
231 }
232
233 pub fn clear() {
235 Self::INSTANCE.with(|registry| {
236 registry.borrow_mut().functions.clear();
237 });
238 }
239}
240
241impl From<&CustomFunction> for Rc<dyn FunctionDefinition> {
242 fn from(custom: &CustomFunction) -> Self {
243 CustomFunctionRegistry::get_definition(&custom.name).unwrap_or_else(
244 || {
245 Rc::new(StaticFunction {
247 signature: FunctionSignature {
248 parameters: vec![],
249 return_type: VariableType::Null,
250 },
251 implementation: Rc::new(|_| {
252 Err(anyhow::anyhow!("自定义函数未找到"))
253 }),
254 })
255 },
256 )
257 }
258}