intuicio_core/
function.rs

1use crate::{
2    Visibility,
3    context::Context,
4    meta::Meta,
5    registry::Registry,
6    types::{Type, TypeHandle, TypeQuery},
7};
8use intuicio_data::data_stack::DataStackPack;
9use rustc_hash::FxHasher;
10use std::{
11    borrow::Cow,
12    hash::{Hash, Hasher},
13    sync::Arc,
14};
15
/// Shared, reference-counted handle to a [`Function`].
pub type FunctionHandle = Arc<Function>;
/// Predicate over a [`Meta`] value; queries use it to filter on attached meta.
pub type FunctionMetaQuery = fn(&Meta) -> bool;
18
19pub enum FunctionBody {
20    Pointer(fn(&mut Context, &Registry)),
21    #[allow(clippy::type_complexity)]
22    Closure(Arc<dyn Fn(&mut Context, &Registry) + Send + Sync>),
23}
24
25impl FunctionBody {
26    pub fn pointer(pointer: fn(&mut Context, &Registry)) -> Self {
27        Self::Pointer(pointer)
28    }
29
30    pub fn closure<T>(closure: T) -> Self
31    where
32        T: Fn(&mut Context, &Registry) + Send + Sync + 'static,
33    {
34        Self::Closure(Arc::new(closure))
35    }
36
37    pub fn invoke(&self, context: &mut Context, registry: &Registry) {
38        match self {
39            Self::Pointer(pointer) => pointer(context, registry),
40            Self::Closure(closure) => closure(context, registry),
41        }
42    }
43}
44
45impl std::fmt::Debug for FunctionBody {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            Self::Pointer(_) => write!(f, "<Pointer>"),
49            Self::Closure(_) => write!(f, "<Closure>"),
50        }
51    }
52}
53
54#[derive(Clone, PartialEq)]
55pub struct FunctionParameter {
56    pub meta: Option<Meta>,
57    pub name: String,
58    pub type_handle: TypeHandle,
59}
60
61impl FunctionParameter {
62    pub fn new(name: impl ToString, type_handle: TypeHandle) -> Self {
63        Self {
64            meta: None,
65            name: name.to_string(),
66            type_handle,
67        }
68    }
69}
70
71impl std::fmt::Debug for FunctionParameter {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("FunctionParameter")
74            .field("meta", &self.meta)
75            .field("name", &self.name)
76            .field("type_handle", &self.type_handle.name())
77            .finish()
78    }
79}
80
81#[derive(Clone, PartialEq)]
82pub struct FunctionSignature {
83    pub meta: Option<Meta>,
84    pub name: String,
85    pub module_name: Option<String>,
86    pub type_handle: Option<TypeHandle>,
87    pub visibility: Visibility,
88    pub inputs: Vec<FunctionParameter>,
89    pub outputs: Vec<FunctionParameter>,
90}
91
92impl FunctionSignature {
93    pub fn new(name: impl ToString) -> Self {
94        Self {
95            meta: None,
96            name: name.to_string(),
97            module_name: None,
98            type_handle: None,
99            visibility: Visibility::default(),
100            inputs: vec![],
101            outputs: vec![],
102        }
103    }
104
105    pub fn with_meta(mut self, meta: Meta) -> Self {
106        self.meta = Some(meta);
107        self
108    }
109
110    pub fn with_module_name(mut self, name: impl ToString) -> Self {
111        self.module_name = Some(name.to_string());
112        self
113    }
114
115    pub fn with_type_handle(mut self, handle: TypeHandle) -> Self {
116        self.type_handle = Some(handle);
117        self
118    }
119
120    pub fn with_visibility(mut self, visibility: Visibility) -> Self {
121        self.visibility = visibility;
122        self
123    }
124
125    pub fn with_input(mut self, parameter: FunctionParameter) -> Self {
126        self.inputs.push(parameter);
127        self
128    }
129
130    pub fn with_output(mut self, parameter: FunctionParameter) -> Self {
131        self.outputs.push(parameter);
132        self
133    }
134}
135
136impl std::fmt::Debug for FunctionSignature {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        f.debug_struct("FunctionSignature")
139            .field("meta", &self.meta)
140            .field("name", &self.name)
141            .field("module_name", &self.module_name)
142            .field(
143                "type_handle",
144                &match self.type_handle.as_ref() {
145                    Some(type_handle) => type_handle.name().to_owned(),
146                    None => "!".to_owned(),
147                },
148            )
149            .field("visibility", &self.visibility)
150            .field("inputs", &self.inputs)
151            .field("outputs", &self.outputs)
152            .finish()
153    }
154}
155
156impl std::fmt::Display for FunctionSignature {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        if let Some(meta) = self.meta.as_ref() {
159            write!(f, "#{} ", meta)?;
160        }
161        if let Some(module_name) = self.module_name.as_ref() {
162            write!(f, "mod {} ", module_name)?;
163        }
164        if let Some(type_handle) = self.type_handle.as_ref() {
165            match &**type_handle {
166                Type::Struct(value) => {
167                    write!(f, "struct {} ", value.type_name())?;
168                }
169                Type::Enum(value) => {
170                    write!(f, "enum {} ", value.type_name())?;
171                }
172            }
173        }
174        write!(f, "fn {}(", self.name)?;
175        for (index, parameter) in self.inputs.iter().enumerate() {
176            if index > 0 {
177                write!(f, ", ")?;
178            }
179            write!(
180                f,
181                "{}: {}",
182                parameter.name,
183                parameter.type_handle.type_name()
184            )?;
185        }
186        write!(f, ") -> (")?;
187        for (index, parameter) in self.outputs.iter().enumerate() {
188            if index > 0 {
189                write!(f, ", ")?;
190            }
191            write!(
192                f,
193                "{}: {}",
194                parameter.name,
195                parameter.type_handle.type_name()
196            )?;
197        }
198        write!(f, ")")
199    }
200}
201
202#[derive(Debug)]
203pub struct Function {
204    signature: FunctionSignature,
205    body: FunctionBody,
206}
207
208impl Function {
209    pub fn new(signature: FunctionSignature, body: FunctionBody) -> Self {
210        Self { signature, body }
211    }
212
213    pub fn signature(&self) -> &FunctionSignature {
214        &self.signature
215    }
216
217    pub fn invoke(&self, context: &mut Context, registry: &Registry) {
218        context.store_registers();
219        self.body.invoke(context, registry);
220        context.restore_registers();
221    }
222
223    pub fn call<O: DataStackPack, I: DataStackPack>(
224        &self,
225        context: &mut Context,
226        registry: &Registry,
227        inputs: I,
228        verify: bool,
229    ) -> O {
230        if verify {
231            self.verify_inputs_outputs::<O, I>();
232        }
233        inputs.stack_push_reversed(context.stack());
234        self.invoke(context, registry);
235        O::stack_pop(context.stack())
236    }
237
238    pub fn verify_inputs_outputs<O: DataStackPack, I: DataStackPack>(&self) {
239        let input_types = I::pack_types();
240        if input_types.len() != self.signature.inputs.len() {
241            panic!("Function: {} got wrong inputs number!", self.signature.name);
242        }
243        let output_types = O::pack_types();
244        if output_types.len() != self.signature.outputs.len() {
245            panic!(
246                "Function: {} got wrong outputs number!",
247                self.signature.name
248            );
249        }
250        for (parameter, type_hash) in self.signature.inputs.iter().zip(input_types) {
251            if parameter.type_handle.type_hash() != type_hash {
252                panic!(
253                    "Function: {} input parameter: {} got wrong value type!",
254                    self.signature.name, parameter.name
255                );
256            }
257        }
258        for (parameter, type_hash) in self.signature.outputs.iter().zip(output_types) {
259            if parameter.type_handle.type_hash() != type_hash {
260                panic!(
261                    "Function: {} output parameter: {} got wrong value type!",
262                    self.signature.name, parameter.name
263                );
264            }
265        }
266    }
267
268    pub fn into_handle(self) -> FunctionHandle {
269        self.into()
270    }
271}
272
273#[derive(Debug, Default, Clone, PartialEq, Hash)]
274pub struct FunctionQueryParameter<'a> {
275    pub name: Option<Cow<'a, str>>,
276    pub type_query: Option<TypeQuery<'a>>,
277    pub meta: Option<FunctionMetaQuery>,
278}
279
280impl FunctionQueryParameter<'_> {
281    pub fn is_valid(&self, parameter: &FunctionParameter) -> bool {
282        self.name
283            .as_ref()
284            .map(|name| name.as_ref() == parameter.name)
285            .unwrap_or(true)
286            && self
287                .type_query
288                .as_ref()
289                .map(|query| query.is_valid(&parameter.type_handle))
290                .unwrap_or(true)
291            && self
292                .meta
293                .as_ref()
294                .map(|query| parameter.meta.as_ref().map(query).unwrap_or(false))
295                .unwrap_or(true)
296    }
297
298    pub fn to_static(&self) -> FunctionQueryParameter<'static> {
299        FunctionQueryParameter {
300            name: self
301                .name
302                .as_ref()
303                .map(|name| name.as_ref().to_owned().into()),
304            type_query: self.type_query.as_ref().map(|query| query.to_static()),
305            meta: self.meta,
306        }
307    }
308}
309
310#[derive(Debug, Default, Clone, PartialEq, Hash)]
311pub struct FunctionQuery<'a> {
312    pub name: Option<Cow<'a, str>>,
313    pub module_name: Option<Cow<'a, str>>,
314    pub type_query: Option<TypeQuery<'a>>,
315    pub visibility: Option<Visibility>,
316    pub inputs: Cow<'a, [FunctionQueryParameter<'a>]>,
317    pub outputs: Cow<'a, [FunctionQueryParameter<'a>]>,
318    pub meta: Option<FunctionMetaQuery>,
319}
320
321impl FunctionQuery<'_> {
322    pub fn is_valid(&self, signature: &FunctionSignature) -> bool {
323        self.name
324            .as_ref()
325            .map(|name| name.as_ref() == signature.name)
326            .unwrap_or(true)
327            && self
328                .module_name
329                .as_ref()
330                .map(|name| {
331                    signature
332                        .module_name
333                        .as_ref()
334                        .map(|module_name| name.as_ref() == module_name)
335                        .unwrap_or(false)
336                })
337                .unwrap_or(true)
338            && self
339                .type_query
340                .as_ref()
341                .map(|query| {
342                    signature
343                        .type_handle
344                        .as_ref()
345                        .map(|handle| query.is_valid(handle))
346                        .unwrap_or(false)
347                })
348                .unwrap_or(true)
349            && self
350                .visibility
351                .map(|visibility| signature.visibility.is_visible(visibility))
352                .unwrap_or(true)
353            && self
354                .inputs
355                .iter()
356                .zip(signature.inputs.iter())
357                .all(|(query, parameter)| query.is_valid(parameter))
358            && self
359                .outputs
360                .iter()
361                .zip(signature.outputs.iter())
362                .all(|(query, parameter)| query.is_valid(parameter))
363            && self
364                .meta
365                .as_ref()
366                .map(|query| signature.meta.as_ref().map(query).unwrap_or(false))
367                .unwrap_or(true)
368    }
369
370    pub fn as_hash(&self) -> u64 {
371        let mut hasher = FxHasher::default();
372        self.hash(&mut hasher);
373        hasher.finish()
374    }
375
376    pub fn to_static(&self) -> FunctionQuery<'static> {
377        FunctionQuery {
378            name: self
379                .name
380                .as_ref()
381                .map(|name| name.as_ref().to_owned().into()),
382            module_name: self
383                .module_name
384                .as_ref()
385                .map(|name| name.as_ref().to_owned().into()),
386            type_query: self.type_query.as_ref().map(|query| query.to_static()),
387            visibility: self.visibility,
388            inputs: self
389                .inputs
390                .as_ref()
391                .iter()
392                .map(|query| query.to_static())
393                .collect(),
394            outputs: self
395                .outputs
396                .as_ref()
397                .iter()
398                .map(|query| query.to_static())
399                .collect(),
400            meta: self.meta,
401        }
402    }
403}
404
/// Builds a [`FunctionSignature`](crate::function::FunctionSignature) from a
/// function-like declaration.
///
/// Syntax:
/// `registry => [mod module] [type (Type)] fn name(a: A, ...) -> (r: R, ...)`.
///
/// Every type mentioned is looked up with
/// `$registry.find_type(TypeQuery::of::<T>())`. The `$registry` expression is
/// expanded once per lookup, so pass a cheap expression such as a variable.
///
/// # Panics
///
/// The expansion calls `unwrap()` on each `find_type` result, so it panics
/// when a lookup does not yield a type.
#[macro_export]
macro_rules! function_signature {
    (
        $registry:expr
        =>
        $(mod $module_name:ident)?
        $(type ($type:ty))?
        fn
        $name:ident
        ($( $input_name:ident : $input_type:ty ),*)
        ->
        ($( $output_name:ident : $output_type:ty ),*)
    ) => {{
        // Names are passed as `&'static str`; the constructors take
        // `impl ToString` and allocate the `String` themselves, so no
        // intermediate `to_owned()` is needed.
        let result = $crate::function::FunctionSignature::new(stringify!($name));
        $(
            let result = result.with_module_name(stringify!($module_name));
        )?
        $(
            let result = result.with_type_handle(
                $registry.find_type($crate::types::TypeQuery::of::<$type>()).unwrap()
            );
        )?
        $(
            let result = result.with_input(
                $crate::function::FunctionParameter::new(
                    stringify!($input_name),
                    $registry.find_type($crate::types::TypeQuery::of::<$input_type>()).unwrap()
                )
            );
        )*
        $(
            let result = result.with_output(
                $crate::function::FunctionParameter::new(
                    stringify!($output_name),
                    $registry.find_type($crate::types::TypeQuery::of::<$output_type>()).unwrap()
                )
            );
        )*
        result
    }};
}
444
/// Builds a [`Function`](crate::function::Function) from a function-like
/// declaration followed by a code block.
///
/// The signature part uses the same syntax as `function_signature!` and is
/// forwarded to it. The body is a `move` closure that pops a tuple of the
/// declared input types from the context stack, binds its elements to the
/// input names, evaluates the code block and pushes the block's value back
/// with `stack_push_reversed`.
///
/// The expansion names `intuicio_data::data_stack::DataStackPack`, so
/// `intuicio_data` has to be resolvable where the macro is invoked.
#[macro_export]
macro_rules! define_function {
    (
        $registry:expr
        =>
        $(mod $module_name:ident)?
        $(type ($type:ty))?
        fn
        $name:ident
        ($( $input_name:ident : $input_type:ty),*)
        ->
        ($( $output_name:ident : $output_type:ty),*)
        $code:block
    ) => {
        $crate::function::Function::new(
            $crate::function_signature! {
                $registry
                =>
                $(mod $module_name)?
                $(type ($type))?
                fn
                $name
                ($($input_name : $input_type),*)
                ->
                ($($output_name : $output_type),*)
            },
            $crate::function::FunctionBody::closure(move |context, registry| {
                use intuicio_data::data_stack::DataStackPack;
                // Inputs arrive as one tuple and are bound to the declared names.
                #[allow(unused_mut)]
                let ($(mut $input_name,)*) = <($($input_type,)*)>::stack_pop(context.stack());
                // The block's value is the pack of outputs.
                let outputs = $code;
                outputs.stack_push_reversed(context.stack());
            }),
        )
    };
}
480
#[cfg(test)]
mod tests {
    use crate as intuicio_core;
    use crate::{context::*, function::*, registry::*, types::struct_type::*};
    use intuicio_data;
    use intuicio_derive::*;

    #[intuicio_function(meta = "foo")]
    fn function_meta() {}

    /// Builds a parameter query that constrains only the parameter name.
    fn named(name: &'static str) -> FunctionQueryParameter<'static> {
        FunctionQueryParameter {
            name: Some(name.into()),
            ..Default::default()
        }
    }

    #[test]
    fn test_function() {
        // Pops two `i32` values and pushes their sum.
        fn add(context: &mut Context, _: &Registry) {
            let a = context.stack().pop::<i32>().unwrap();
            let b = context.stack().pop::<i32>().unwrap();
            context.stack().push(a + b);
        }

        let i32_handle = NativeStructBuilder::new::<i32>()
            .build()
            .into_type()
            .into_handle();
        let signature = FunctionSignature::new("add")
            .with_input(FunctionParameter::new("a", i32_handle.clone()))
            .with_input(FunctionParameter::new("b", i32_handle.clone()))
            .with_output(FunctionParameter::new("result", i32_handle));
        let function = Function::new(signature.clone(), FunctionBody::pointer(add));

        // An empty query matches everything.
        assert!(FunctionQuery::default().is_valid(&signature));

        // Name-only query.
        let by_name = FunctionQuery {
            name: Some("add".into()),
            ..Default::default()
        };
        assert!(by_name.is_valid(&signature));

        // Name plus parameters in declaration order.
        let inputs_in_order = [named("a"), named("b")];
        let outputs_in_order = [named("result")];
        let in_order = FunctionQuery {
            name: Some("add".into()),
            inputs: inputs_in_order.as_slice().into(),
            outputs: outputs_in_order.as_slice().into(),
            ..Default::default()
        };
        assert!(in_order.is_valid(&signature));

        // Swapped input names must not match.
        let inputs_swapped = [named("b"), named("a")];
        let swapped = FunctionQuery {
            name: Some("add".into()),
            inputs: inputs_swapped.as_slice().into(),
            ..Default::default()
        };
        assert!(!swapped.is_valid(&signature));

        let mut context = Context::new(10240, 10240);
        let registry = Registry::default();

        context.stack().push(2);
        context.stack().push(40);
        function.invoke(&mut context, &registry);
        assert_eq!(context.stack().pop::<i32>().unwrap(), 42);

        assert_eq!(
            function_meta::define_signature(&registry).meta,
            Some(Meta::Identifier("foo".to_owned()))
        );
    }
}