intuicio_core/
function.rs

1use crate::{
2    Visibility,
3    context::Context,
4    meta::Meta,
5    registry::Registry,
6    types::{Type, TypeHandle, TypeQuery},
7};
8use intuicio_data::data_stack::DataStackPack;
9use rustc_hash::FxHasher;
10use std::{
11    borrow::Cow,
12    hash::{Hash, Hasher},
13    sync::Arc,
14};
15
/// Shared, reference-counted handle to a [`Function`].
pub type FunctionHandle = Arc<Function>;
/// Predicate over a [`Meta`] value; queries use it to filter by metadata.
pub type FunctionMetaQuery = fn(&Meta) -> bool;
18
19pub enum FunctionBody {
20    Pointer(fn(&mut Context, &Registry)),
21    #[allow(clippy::type_complexity)]
22    Closure(Arc<dyn Fn(&mut Context, &Registry) + Send + Sync>),
23}
24
25impl FunctionBody {
26    pub fn pointer(pointer: fn(&mut Context, &Registry)) -> Self {
27        Self::Pointer(pointer)
28    }
29
30    pub fn closure<T>(closure: T) -> Self
31    where
32        T: Fn(&mut Context, &Registry) + Send + Sync + 'static,
33    {
34        Self::Closure(Arc::new(closure))
35    }
36
37    pub fn invoke(&self, context: &mut Context, registry: &Registry) {
38        match self {
39            Self::Pointer(pointer) => pointer(context, registry),
40            Self::Closure(closure) => closure(context, registry),
41        }
42    }
43}
44
45impl std::fmt::Debug for FunctionBody {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            Self::Pointer(_) => write!(f, "<Pointer>"),
49            Self::Closure(_) => write!(f, "<Closure>"),
50        }
51    }
52}
53
54#[derive(Clone, PartialEq)]
55pub struct FunctionParameter {
56    pub meta: Option<Meta>,
57    pub name: String,
58    pub type_handle: TypeHandle,
59}
60
61impl FunctionParameter {
62    pub fn new(name: impl ToString, type_handle: TypeHandle) -> Self {
63        Self {
64            meta: None,
65            name: name.to_string(),
66            type_handle,
67        }
68    }
69
70    pub fn with_meta(mut self, meta: Meta) -> Self {
71        self.meta = Some(meta);
72        self
73    }
74}
75
76impl std::fmt::Debug for FunctionParameter {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.debug_struct("FunctionParameter")
79            .field("meta", &self.meta)
80            .field("name", &self.name)
81            .field("type_handle", &self.type_handle.name())
82            .finish()
83    }
84}
85
86#[derive(Clone, PartialEq)]
87pub struct FunctionSignature {
88    pub meta: Option<Meta>,
89    pub name: String,
90    pub module_name: Option<String>,
91    pub type_handle: Option<TypeHandle>,
92    pub visibility: Visibility,
93    pub inputs: Vec<FunctionParameter>,
94    pub outputs: Vec<FunctionParameter>,
95}
96
97impl FunctionSignature {
98    pub fn new(name: impl ToString) -> Self {
99        Self {
100            meta: None,
101            name: name.to_string(),
102            module_name: None,
103            type_handle: None,
104            visibility: Visibility::default(),
105            inputs: vec![],
106            outputs: vec![],
107        }
108    }
109
110    pub fn with_meta(mut self, meta: Meta) -> Self {
111        self.meta = Some(meta);
112        self
113    }
114
115    pub fn with_module_name(mut self, name: impl ToString) -> Self {
116        self.module_name = Some(name.to_string());
117        self
118    }
119
120    pub fn with_type_handle(mut self, handle: TypeHandle) -> Self {
121        self.type_handle = Some(handle);
122        self
123    }
124
125    pub fn with_visibility(mut self, visibility: Visibility) -> Self {
126        self.visibility = visibility;
127        self
128    }
129
130    pub fn with_input(mut self, parameter: FunctionParameter) -> Self {
131        self.inputs.push(parameter);
132        self
133    }
134
135    pub fn with_output(mut self, parameter: FunctionParameter) -> Self {
136        self.outputs.push(parameter);
137        self
138    }
139}
140
141impl std::fmt::Debug for FunctionSignature {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        f.debug_struct("FunctionSignature")
144            .field("meta", &self.meta)
145            .field("name", &self.name)
146            .field("module_name", &self.module_name)
147            .field(
148                "type_handle",
149                &match self.type_handle.as_ref() {
150                    Some(type_handle) => type_handle.name().to_owned(),
151                    None => "!".to_owned(),
152                },
153            )
154            .field("visibility", &self.visibility)
155            .field("inputs", &self.inputs)
156            .field("outputs", &self.outputs)
157            .finish()
158    }
159}
160
161impl std::fmt::Display for FunctionSignature {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        if let Some(meta) = self.meta.as_ref() {
164            write!(f, "#{} ", meta)?;
165        }
166        if let Some(module_name) = self.module_name.as_ref() {
167            write!(f, "mod {} ", module_name)?;
168        }
169        if let Some(type_handle) = self.type_handle.as_ref() {
170            match &**type_handle {
171                Type::Struct(value) => {
172                    write!(f, "struct {} ", value.type_name())?;
173                }
174                Type::Enum(value) => {
175                    write!(f, "enum {} ", value.type_name())?;
176                }
177            }
178        }
179        write!(f, "fn {}(", self.name)?;
180        for (index, parameter) in self.inputs.iter().enumerate() {
181            if index > 0 {
182                write!(f, ", ")?;
183            }
184            write!(
185                f,
186                "{}: {}",
187                parameter.name,
188                parameter.type_handle.type_name()
189            )?;
190        }
191        write!(f, ") -> (")?;
192        for (index, parameter) in self.outputs.iter().enumerate() {
193            if index > 0 {
194                write!(f, ", ")?;
195            }
196            write!(
197                f,
198                "{}: {}",
199                parameter.name,
200                parameter.type_handle.type_name()
201            )?;
202        }
203        write!(f, ")")
204    }
205}
206
207#[derive(Debug)]
208pub struct Function {
209    signature: FunctionSignature,
210    body: FunctionBody,
211}
212
213impl Function {
214    pub fn new(signature: FunctionSignature, body: FunctionBody) -> Self {
215        Self { signature, body }
216    }
217
218    pub fn signature(&self) -> &FunctionSignature {
219        &self.signature
220    }
221
222    pub fn invoke(&self, context: &mut Context, registry: &Registry) {
223        context.store_registers();
224        self.body.invoke(context, registry);
225        context.restore_registers();
226    }
227
228    pub fn call<O: DataStackPack, I: DataStackPack>(
229        &self,
230        context: &mut Context,
231        registry: &Registry,
232        inputs: I,
233        verify: bool,
234    ) -> O {
235        if verify {
236            self.verify_inputs_outputs::<O, I>();
237        }
238        inputs.stack_push_reversed(context.stack());
239        self.invoke(context, registry);
240        O::stack_pop(context.stack())
241    }
242
243    pub fn verify_inputs_outputs<O: DataStackPack, I: DataStackPack>(&self) {
244        let input_types = I::pack_types();
245        if input_types.len() != self.signature.inputs.len() {
246            panic!("Function: {} got wrong inputs number!", self.signature.name);
247        }
248        let output_types = O::pack_types();
249        if output_types.len() != self.signature.outputs.len() {
250            panic!(
251                "Function: {} got wrong outputs number!",
252                self.signature.name
253            );
254        }
255        for (parameter, type_hash) in self.signature.inputs.iter().zip(input_types) {
256            if parameter.type_handle.type_hash() != type_hash {
257                panic!(
258                    "Function: {} input parameter: {} got wrong value type!",
259                    self.signature.name, parameter.name
260                );
261            }
262        }
263        for (parameter, type_hash) in self.signature.outputs.iter().zip(output_types) {
264            if parameter.type_handle.type_hash() != type_hash {
265                panic!(
266                    "Function: {} output parameter: {} got wrong value type!",
267                    self.signature.name, parameter.name
268                );
269            }
270        }
271    }
272
273    pub fn into_handle(self) -> FunctionHandle {
274        self.into()
275    }
276}
277
278#[derive(Debug, Default, Clone, PartialEq, Hash)]
279pub struct FunctionQueryParameter<'a> {
280    pub name: Option<Cow<'a, str>>,
281    pub type_query: Option<TypeQuery<'a>>,
282    pub meta: Option<FunctionMetaQuery>,
283}
284
285impl FunctionQueryParameter<'_> {
286    pub fn is_valid(&self, parameter: &FunctionParameter) -> bool {
287        self.name
288            .as_ref()
289            .map(|name| name.as_ref() == parameter.name)
290            .unwrap_or(true)
291            && self
292                .type_query
293                .as_ref()
294                .map(|query| query.is_valid(&parameter.type_handle))
295                .unwrap_or(true)
296            && self
297                .meta
298                .as_ref()
299                .map(|query| parameter.meta.as_ref().map(query).unwrap_or(false))
300                .unwrap_or(true)
301    }
302
303    pub fn to_static(&self) -> FunctionQueryParameter<'static> {
304        FunctionQueryParameter {
305            name: self
306                .name
307                .as_ref()
308                .map(|name| name.as_ref().to_owned().into()),
309            type_query: self.type_query.as_ref().map(|query| query.to_static()),
310            meta: self.meta,
311        }
312    }
313}
314
315#[derive(Debug, Default, Clone, PartialEq, Hash)]
316pub struct FunctionQuery<'a> {
317    pub name: Option<Cow<'a, str>>,
318    pub module_name: Option<Cow<'a, str>>,
319    pub type_query: Option<TypeQuery<'a>>,
320    pub visibility: Option<Visibility>,
321    pub inputs: Cow<'a, [FunctionQueryParameter<'a>]>,
322    pub outputs: Cow<'a, [FunctionQueryParameter<'a>]>,
323    pub meta: Option<FunctionMetaQuery>,
324}
325
326impl FunctionQuery<'_> {
327    pub fn is_valid(&self, signature: &FunctionSignature) -> bool {
328        self.name
329            .as_ref()
330            .map(|name| name.as_ref() == signature.name)
331            .unwrap_or(true)
332            && self
333                .module_name
334                .as_ref()
335                .map(|name| {
336                    signature
337                        .module_name
338                        .as_ref()
339                        .map(|module_name| name.as_ref() == module_name)
340                        .unwrap_or(false)
341                })
342                .unwrap_or(true)
343            && self
344                .type_query
345                .as_ref()
346                .map(|query| {
347                    signature
348                        .type_handle
349                        .as_ref()
350                        .map(|handle| query.is_valid(handle))
351                        .unwrap_or(false)
352                })
353                .unwrap_or(true)
354            && self
355                .visibility
356                .map(|visibility| signature.visibility.is_visible(visibility))
357                .unwrap_or(true)
358            && self
359                .inputs
360                .iter()
361                .zip(signature.inputs.iter())
362                .all(|(query, parameter)| query.is_valid(parameter))
363            && self
364                .outputs
365                .iter()
366                .zip(signature.outputs.iter())
367                .all(|(query, parameter)| query.is_valid(parameter))
368            && self
369                .meta
370                .as_ref()
371                .map(|query| signature.meta.as_ref().map(query).unwrap_or(false))
372                .unwrap_or(true)
373    }
374
375    pub fn as_hash(&self) -> u64 {
376        let mut hasher = FxHasher::default();
377        self.hash(&mut hasher);
378        hasher.finish()
379    }
380
381    pub fn to_static(&self) -> FunctionQuery<'static> {
382        FunctionQuery {
383            name: self
384                .name
385                .as_ref()
386                .map(|name| name.as_ref().to_owned().into()),
387            module_name: self
388                .module_name
389                .as_ref()
390                .map(|name| name.as_ref().to_owned().into()),
391            type_query: self.type_query.as_ref().map(|query| query.to_static()),
392            visibility: self.visibility,
393            inputs: self
394                .inputs
395                .as_ref()
396                .iter()
397                .map(|query| query.to_static())
398                .collect(),
399            outputs: self
400                .outputs
401                .as_ref()
402                .iter()
403                .map(|query| query.to_static())
404                .collect(),
405            meta: self.meta,
406        }
407    }
408}
409
/// Builds a `FunctionSignature` from a Rust-like declaration, resolving every
/// mentioned type through the given registry.
///
/// Syntax:
/// `registry => [mod name] [type (Type)] fn name(arg: Type, ...) -> (out: Type, ...)`
///
/// The registry expression is evaluated exactly once, no matter how many
/// types have to be resolved.
///
/// # Panics
/// Panics when a mentioned type is not found by `find_type`; the message
/// names the missing type.
#[macro_export]
macro_rules! function_signature {
    (
        $registry:expr
        =>
        $(mod $module_name:ident)?
        $(type ($type:ty))?
        fn
        $name:ident
        ($( $input_name:ident : $input_type:ty ),*)
        ->
        ($( $output_name:ident : $output_type:ty ),*)
    ) => {{
        // Bind the registry once: expanding `$registry` at every use site would
        // re-evaluate the caller's expression for each resolved type.
        // Unused when the declaration mentions no types at all.
        #[allow(unused_variables)]
        let registry = &$registry;
        // Stays unmodified when the declaration has no optional parts.
        #[allow(unused_mut)]
        let mut result = $crate::function::FunctionSignature::new(stringify!($name));
        $(
            result.module_name = Some(stringify!($module_name).to_owned());
        )?
        $(
            result.type_handle = Some(
                registry
                    .find_type($crate::types::TypeQuery::of::<$type>())
                    .expect(concat!("Type not found in registry: ", stringify!($type)))
            );
        )?
        $(
            result.inputs.push(
                // `new` takes `impl ToString`, so the literal is passed as-is
                // instead of allocating a `String` that gets copied again.
                $crate::function::FunctionParameter::new(
                    stringify!($input_name),
                    registry
                        .find_type($crate::types::TypeQuery::of::<$input_type>())
                        .expect(concat!("Type not found in registry: ", stringify!($input_type)))
                )
            );
        )*
        $(
            result.outputs.push(
                $crate::function::FunctionParameter::new(
                    stringify!($output_name),
                    registry
                        .find_type($crate::types::TypeQuery::of::<$output_type>())
                        .expect(concat!("Type not found in registry: ", stringify!($output_type)))
                )
            );
        )*
        result
    }};
}
449
/// Defines a `Function` from a Rust-like declaration plus a code block.
///
/// Syntax:
/// `registry => [mod name] [type (Type)] fn name(arg: Type, ...) -> (out: Type, ...) { code }`
///
/// The signature part is forwarded to `function_signature!`. The body becomes
/// a `move` closure that pops the inputs from the context stack as one tuple,
/// evaluates `$code`, and pushes the block's value back onto the stack with
/// `stack_push_reversed`.
///
/// The expansion names `intuicio_data::data_stack::DataStackPack` directly
/// (not through `$crate`), so that path has to resolve at the call site.
#[macro_export]
macro_rules! define_function {
    (
        $registry:expr
        =>
        $(mod $module_name:ident)?
        $(type ($type:ty))?
        fn
        $name:ident
        ($( $input_name:ident : $input_type:ty),*)
        ->
        ($( $output_name:ident : $output_type:ty),*)
        $code:block
    ) => {
        $crate::function::Function::new(
            // Signature: same tokens, resolved against `$registry`.
            $crate::function_signature! {
                $registry
                =>
                $(mod $module_name)?
                $(type ($type))?
                fn
                $name
                ($($input_name : $input_type),*)
                ->
                ($($output_name : $output_type),*)
            },
            $crate::function::FunctionBody::closure(move |context, registry| {
                use intuicio_data::data_stack::DataStackPack;
                // Inputs are bound as `mut` so `$code` may modify them;
                // the lint is silenced for bodies that do not.
                #[allow(unused_mut)]
                let ($(mut $input_name,)*) = <($($input_type,)*)>::stack_pop(context.stack());
                // The block's value is the output pack pushed back to the stack.
                $code.stack_push_reversed(context.stack());
            }),
        )
    };
}
485
#[cfg(test)]
mod tests {
    // Alias the current crate under its external name.
    // NOTE(review): presumably needed because the `intuicio_derive` expansion
    // refers to `intuicio_core::...` paths — confirm.
    use crate as intuicio_core;
    use crate::{context::*, function::*, registry::*, types::struct_type::*};
    use intuicio_data;
    use intuicio_derive::*;

    // Fixture for the metadata assertion at the end of `test_function`.
    #[intuicio_function(meta = "foo", args_meta(_bar = "foo"))]
    fn function_meta(_bar: bool) {}

    /// Covers query matching against a signature, invoking a pointer-backed
    /// function, and metadata produced by `#[intuicio_function]`.
    #[test]
    fn test_function() {
        // Pops two `i32` values from the stack and pushes their sum.
        fn add(context: &mut Context, _: &Registry) {
            let a = context.stack().pop::<i32>().unwrap();
            let b = context.stack().pop::<i32>().unwrap();
            context.stack().push(a + b);
        }

        let i32_handle = NativeStructBuilder::new::<i32>()
            .build()
            .into_type()
            .into_handle();
        let signature = FunctionSignature::new("add")
            .with_input(FunctionParameter::new("a", i32_handle.clone()))
            .with_input(FunctionParameter::new("b", i32_handle.clone()))
            .with_output(FunctionParameter::new("result", i32_handle));
        let function = Function::new(signature.to_owned(), FunctionBody::pointer(add));

        // An empty query matches any signature.
        assert!(FunctionQuery::default().is_valid(&signature));
        // Matching by function name only.
        assert!(
            FunctionQuery {
                name: Some("add".into()),
                ..Default::default()
            }
            .is_valid(&signature)
        );
        // Matching by function name plus parameter names in declaration order.
        assert!(
            FunctionQuery {
                name: Some("add".into()),
                inputs: [
                    FunctionQueryParameter {
                        name: Some("a".into()),
                        ..Default::default()
                    },
                    FunctionQueryParameter {
                        name: Some("b".into()),
                        ..Default::default()
                    }
                ]
                .as_slice()
                .into(),
                outputs: [FunctionQueryParameter {
                    name: Some("result".into()),
                    ..Default::default()
                }]
                .as_slice()
                .into(),
                ..Default::default()
            }
            .is_valid(&signature)
        );
        // Parameter filters are positional: swapped input names must not match.
        assert!(
            !FunctionQuery {
                name: Some("add".into()),
                inputs: [
                    FunctionQueryParameter {
                        name: Some("b".into()),
                        ..Default::default()
                    },
                    FunctionQueryParameter {
                        name: Some("a".into()),
                        ..Default::default()
                    }
                ]
                .as_slice()
                .into(),
                ..Default::default()
            }
            .is_valid(&signature)
        );

        let mut context = Context::new(10240, 10240);
        let registry = Registry::default().with_basic_types();

        // Invoke through the raw stack: push both operands, read back the sum.
        context.stack().push(2);
        context.stack().push(40);
        function.invoke(&mut context, &registry);
        assert_eq!(context.stack().pop::<i32>().unwrap(), 42);

        // The attribute's `meta = "foo"` ends up on the generated signature.
        assert_eq!(
            function_meta::define_signature(&registry).meta,
            Some(Meta::Identifier("foo".to_owned()))
        );
    }
}