mf_expression/functions/
custom.rs1use crate::functions::defs::{
6 FunctionDefinition, FunctionSignature, StaticFunction,
7};
8use crate::functions::arguments::Arguments;
9use crate::variable::{Variable, VariableType};
10use std::rc::Rc;
11use std::sync::Arc;
12use std::collections::HashMap;
13use std::cell::RefCell;
14use std::fmt::Display;
15use anyhow::Result as AnyhowResult;
16use std::any::Any;
17use std::marker::PhantomData;
18
/// A reference to a user-registered custom function, identified by its name.
///
/// A value can be obtained through `TryFrom<&str>`, which only succeeds for
/// names present in the custom function registry, or directly via `new`.
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub struct CustomFunction {
    /// Name under which the function is registered.
    pub name: String,
}

impl CustomFunction {
    /// Wraps `name` as-is; no check is made that it is registered.
    pub fn new(name: String) -> Self {
        CustomFunction { name }
    }
}

impl Display for CustomFunction {
    /// Writes the bare function name (formatter width/alignment flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.name)
    }
}
40
41impl TryFrom<&str> for CustomFunction {
42 type Error = strum::ParseError;
43
44 fn try_from(value: &str) -> Result<Self, Self::Error> {
45 if CustomFunctionRegistry::is_registered(value) {
47 Ok(CustomFunction::new(value.to_string()))
48 } else {
49 Err(strum::ParseError::VariantNotFound)
50 }
51 }
52}
53
54type ErasedExecutor = Box<
56 dyn Fn(
57 &Arguments,
58 Option<&Arc<dyn Any + Send + Sync>>,
59 ) -> AnyhowResult<Variable>
60 + 'static,
61>;
62
63pub struct CustomFunctionDefinition {
65 pub name: String,
67 pub signature: FunctionSignature,
69 pub executor: ErasedExecutor,
71}
72
73impl CustomFunctionDefinition {
74 pub fn new(
75 name: String,
76 signature: FunctionSignature,
77 executor: ErasedExecutor,
78 ) -> Self {
79 Self { name, signature, executor }
80 }
81}
82
83impl FunctionDefinition for CustomFunctionDefinition {
84 fn call(
85 &self,
86 args: Arguments,
87 ) -> AnyhowResult<Variable> {
88 let state = CURRENT_STATE.with(|s| s.borrow().clone());
90 (self.executor)(&args, state.as_ref())
91 }
92
93 fn required_parameters(&self) -> usize {
94 self.signature.parameters.len()
95 }
96
97 fn optional_parameters(&self) -> usize {
98 0 }
100
101 fn check_types(
102 &self,
103 args: &[Rc<VariableType>],
104 ) -> crate::functions::defs::FunctionTypecheck {
105 let mut typecheck =
106 crate::functions::defs::FunctionTypecheck::default();
107 typecheck.return_type = self.signature.return_type.clone();
108
109 if args.len() != self.required_parameters() {
110 typecheck.general = Some(format!(
111 "期望 `{}` 参数, 实际 `{}` 参数.",
112 self.required_parameters(),
113 args.len()
114 ));
115 }
116
117 for (i, (arg, expected_type)) in
119 args.iter().zip(self.signature.parameters.iter()).enumerate()
120 {
121 if !arg.satisfies(expected_type) {
122 typecheck.arguments.push((
123 i,
124 format!(
125 "参数类型 `{arg}` 不能赋值给参数类型 `{expected_type}`.",
126 ),
127 ));
128 }
129 }
130
131 typecheck
132 }
133
134 fn param_type(
135 &self,
136 index: usize,
137 ) -> Option<VariableType> {
138 self.signature.parameters.get(index).cloned()
139 }
140
141 fn param_type_str(
142 &self,
143 index: usize,
144 ) -> String {
145 self.signature
146 .parameters
147 .get(index)
148 .map(|x| x.to_string())
149 .unwrap_or_else(|| "never".to_string())
150 }
151
152 fn return_type(&self) -> VariableType {
153 self.signature.return_type.clone()
154 }
155
156 fn return_type_str(&self) -> String {
157 self.signature.return_type.to_string()
158 }
159}
160
thread_local!(
    /// Per-thread state passed to custom function executors on each call.
    ///
    /// Stored type-erased so a single slot can hold any `Send + Sync` state
    /// type; it starts out empty (`None`).
    static CURRENT_STATE: RefCell<Option<Arc<dyn Any + Send + Sync>>> = RefCell::default();
);
165
166pub struct CustomFunctionRegistry {
168 functions: HashMap<String, Rc<CustomFunctionDefinition>>,
169}
170
171impl CustomFunctionRegistry {
172 thread_local!(
173 static INSTANCE: RefCell<CustomFunctionRegistry> = RefCell::new(CustomFunctionRegistry::new())
174 );
175
176 fn new() -> Self {
177 Self { functions: HashMap::new() }
178 }
179
180 fn register_function_erased(
182 name: String,
183 signature: FunctionSignature,
184 executor: ErasedExecutor,
185 ) -> Result<(), String> {
186 Self::INSTANCE.with(|registry| {
187 let mut reg = registry.borrow_mut();
188 if reg.functions.contains_key(&name) {
189 return Err(format!("函数 '{}' 已经存在", name));
190 }
191
192 let definition = CustomFunctionDefinition::new(
193 name.clone(),
194 signature,
195 executor,
196 );
197 reg.functions.insert(name, Rc::new(definition));
198 Ok(())
199 })
200 }
201
202 pub fn get_definition(name: &str) -> Option<Rc<dyn FunctionDefinition>> {
204 Self::INSTANCE.with(|registry| {
205 registry
206 .borrow()
207 .functions
208 .get(name)
209 .map(|def| def.clone() as Rc<dyn FunctionDefinition>)
210 })
211 }
212
213 pub fn is_registered(name: &str) -> bool {
215 Self::INSTANCE
216 .with(|registry| registry.borrow().functions.contains_key(name))
217 }
218
219 pub fn set_current_state<S: Send + Sync + 'static>(state: Option<Arc<S>>) {
221 CURRENT_STATE.with(|s| {
222 *s.borrow_mut() = state.map(|st| st as Arc<dyn Any + Send + Sync>);
223 });
224 }
225
226 pub fn has_current_state() -> bool {
228 CURRENT_STATE.with(|s| s.borrow().is_some())
229 }
230
231 pub fn clear_current_state() {
233 CURRENT_STATE.with(|s| {
234 *s.borrow_mut() = None;
235 });
236 }
237
238 pub fn list_functions() -> Vec<String> {
240 Self::INSTANCE.with(|registry| {
241 registry.borrow().functions.keys().cloned().collect()
242 })
243 }
244
245 pub fn clear() {
247 Self::INSTANCE.with(|registry| {
248 registry.borrow_mut().functions.clear();
249 });
250 }
251}
252
253pub struct CustomFunctionHelper<S> {
255 _marker: PhantomData<S>,
256}
257
258impl<S: Send + Sync + 'static> CustomFunctionHelper<S> {
259 pub fn new() -> Self {
261 Self { _marker: PhantomData }
262 }
263
264 pub fn register_function(
272 &self,
273 name: String,
274 params: Vec<VariableType>,
275 return_type: VariableType,
276 executor: Box<
277 dyn Fn(&Arguments, Option<&S>) -> AnyhowResult<Variable> + 'static,
278 >,
279 ) -> Result<(), String> {
280 let signature = FunctionSignature { parameters: params, return_type };
281
282 let wrapped_executor: ErasedExecutor =
283 Box::new(move |args, state_any| {
284 let typed_state = state_any.and_then(|s| s.downcast_ref::<S>());
285 executor(args, typed_state)
286 });
287
288 CustomFunctionRegistry::register_function_erased(
289 name,
290 signature,
291 wrapped_executor,
292 )
293 }
294}
295
296impl<S: Send + Sync + 'static> Default for CustomFunctionHelper<S> {
297 fn default() -> Self {
298 Self::new()
299 }
300}
301
302impl From<&CustomFunction> for Rc<dyn FunctionDefinition> {
303 fn from(custom: &CustomFunction) -> Self {
304 CustomFunctionRegistry::get_definition(&custom.name).unwrap_or_else(
305 || {
306 Rc::new(StaticFunction {
308 signature: FunctionSignature {
309 parameters: vec![],
310 return_type: VariableType::Null,
311 },
312 implementation: Rc::new(|_| {
313 Err(anyhow::anyhow!("自定义函数未找到"))
314 }),
315 })
316 },
317 )
318 }
319}