intuicio_backend_vm/
scope.rs

1use crate::debugger::VmDebuggerHandle;
2use intuicio_core::{
3    context::Context,
4    function::FunctionBody,
5    registry::{Registry, RegistryHandle},
6    script::{ScriptExpression, ScriptFunctionGenerator, ScriptHandle, ScriptOperation},
7};
8use intuicio_data::managed::{ManagedLazy, ManagedRefMut};
9use typid::ID;
10
11pub type VmScopeSymbol = ID<()>;
12
/// Outcome of advancing a [`VmScope`] by a single step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VmScopeResult {
    /// The step executed and the scope still has work left.
    Continue,
    /// The scope has finished executing.
    Completed,
    /// The script executed a `Suspend` operation; stepping again resumes it.
    Suspended,
}

impl VmScopeResult {
    /// `true` only for [`VmScopeResult::Continue`].
    pub fn can_continue(self) -> bool {
        matches!(self, Self::Continue)
    }

    /// `true` only for [`VmScopeResult::Completed`].
    pub fn is_completed(self) -> bool {
        matches!(self, Self::Completed)
    }

    /// `true` only for [`VmScopeResult::Suspended`].
    pub fn is_suspended(self) -> bool {
        matches!(self, Self::Suspended)
    }

    /// `true` for every result except [`VmScopeResult::Completed`]; a suspended
    /// scope can still make progress.
    pub fn can_progress(self) -> bool {
        !matches!(self, Self::Completed)
    }
}
37
38pub struct VmScope<'a, SE: ScriptExpression> {
39    handle: ScriptHandle<'a, SE>,
40    symbol: VmScopeSymbol,
41    position: usize,
42    child: Option<Box<Self>>,
43    debugger: Option<VmDebuggerHandle<SE>>,
44}
45
46impl<'a, SE: ScriptExpression> VmScope<'a, SE> {
47    pub fn new(handle: ScriptHandle<'a, SE>, symbol: VmScopeSymbol) -> Self {
48        Self {
49            handle,
50            symbol,
51            position: 0,
52            child: None,
53            debugger: None,
54        }
55    }
56
57    /// # Safety
58    pub unsafe fn restore(mut self, position: usize, child: Option<Self>) -> Self {
59        self.position = position;
60        self.child = child.map(Box::new);
61        self
62    }
63
64    pub fn with_debugger(mut self, debugger: Option<VmDebuggerHandle<SE>>) -> Self {
65        self.debugger = debugger;
66        self
67    }
68
69    #[allow(clippy::type_complexity)]
70    pub fn into_inner(
71        self,
72    ) -> (
73        ScriptHandle<'a, SE>,
74        VmScopeSymbol,
75        usize,
76        Option<Box<Self>>,
77        Option<VmDebuggerHandle<SE>>,
78    ) {
79        (
80            self.handle,
81            self.symbol,
82            self.position,
83            self.child,
84            self.debugger,
85        )
86    }
87
88    pub fn symbol(&self) -> VmScopeSymbol {
89        self.symbol
90    }
91
92    pub fn position(&self) -> usize {
93        self.position
94    }
95
96    pub fn has_completed(&self) -> bool {
97        self.position >= self.handle.len()
98    }
99
100    pub fn child(&self) -> Option<&Self> {
101        self.child.as_deref()
102    }
103
104    pub fn run(&mut self, context: &mut Context, registry: &Registry) {
105        while self.step(context, registry).can_progress() {}
106    }
107
108    pub fn run_until_suspended(
109        &mut self,
110        context: &mut Context,
111        registry: &Registry,
112    ) -> VmScopeResult {
113        loop {
114            match self.step(context, registry) {
115                VmScopeResult::Continue => {}
116                result => return result,
117            }
118        }
119    }
120
121    pub fn step(&mut self, context: &mut Context, registry: &Registry) -> VmScopeResult {
122        if let Some(child) = &mut self.child {
123            match child.step(context, registry) {
124                VmScopeResult::Completed => {
125                    self.child = None;
126                }
127                result => return result,
128            }
129        }
130        if self.position == 0
131            && let Some(debugger) = self.debugger.as_ref()
132            && let Ok(mut debugger) = debugger.try_write()
133        {
134            debugger.on_enter_scope(self, context, registry);
135        }
136        let result = if let Some(operation) = self.handle.get(self.position) {
137            if let Some(debugger) = self.debugger.as_ref()
138                && let Ok(mut debugger) = debugger.try_write()
139            {
140                debugger.on_enter_operation(self, operation, self.position, context, registry);
141            }
142            let position = self.position;
143            let result = match operation {
144                ScriptOperation::None => {
145                    self.position += 1;
146                    VmScopeResult::Continue
147                }
148                ScriptOperation::Expression { expression } => {
149                    expression.evaluate(context, registry);
150                    self.position += 1;
151                    VmScopeResult::Continue
152                }
153                ScriptOperation::DefineRegister { query } => {
154                    let handle = registry
155                        .types()
156                        .find(|handle| query.is_valid(handle))
157                        .unwrap_or_else(|| {
158                            panic!("Could not define register for non-existent type: {query:#?}")
159                        });
160                    unsafe {
161                        context
162                            .registers()
163                            .push_register_raw(handle.type_hash(), *handle.layout())
164                    };
165                    self.position += 1;
166                    VmScopeResult::Continue
167                }
168                ScriptOperation::DropRegister { index } => {
169                    let index = context.absolute_register_index(*index);
170                    context
171                        .registers()
172                        .access_register(index)
173                        .unwrap_or_else(|| {
174                            panic!("Could not access non-existent register: {index}")
175                        })
176                        .free();
177                    self.position += 1;
178                    VmScopeResult::Continue
179                }
180                ScriptOperation::PushFromRegister { index } => {
181                    let index = context.absolute_register_index(*index);
182                    let (stack, registers) = context.stack_and_registers();
183                    let mut register = registers.access_register(index).unwrap_or_else(|| {
184                        panic!("Could not access non-existent register: {index}")
185                    });
186                    if !stack.push_from_register(&mut register) {
187                        panic!("Could not push data from register: {index}");
188                    }
189                    self.position += 1;
190                    VmScopeResult::Continue
191                }
192                ScriptOperation::PopToRegister { index } => {
193                    let index = context.absolute_register_index(*index);
194                    let (stack, registers) = context.stack_and_registers();
195                    let mut register = registers.access_register(index).unwrap_or_else(|| {
196                        panic!("Could not access non-existent register: {index}")
197                    });
198                    if !stack.pop_to_register(&mut register) {
199                        panic!("Could not pop data to register: {index}");
200                    }
201                    self.position += 1;
202                    VmScopeResult::Continue
203                }
204                ScriptOperation::MoveRegister { from, to } => {
205                    let from = context.absolute_register_index(*from);
206                    let to = context.absolute_register_index(*to);
207                    let (mut source, mut target) = context
208                        .registers()
209                        .access_registers_pair(from, to)
210                        .unwrap_or_else(|| {
211                            panic!("Could not access non-existent registers pair: {from} and {to}")
212                        });
213                    source.move_to(&mut target);
214                    self.position += 1;
215                    VmScopeResult::Continue
216                }
217                ScriptOperation::CallFunction { query } => {
218                    let handle = registry
219                        .functions()
220                        .find(|handle| query.is_valid(handle.signature()))
221                        .unwrap_or_else(|| {
222                            panic!("Could not call non-existent function: {query:#?}")
223                        });
224                    handle.invoke(context, registry);
225                    self.position += 1;
226                    VmScopeResult::Continue
227                }
228                ScriptOperation::BranchScope {
229                    scope_success,
230                    scope_failure,
231                } => {
232                    if context.stack().pop::<bool>().unwrap() {
233                        self.child = Some(Box::new(
234                            Self::new(scope_success.clone(), self.symbol)
235                                .with_debugger(self.debugger.clone()),
236                        ));
237                    } else if let Some(scope_failure) = scope_failure {
238                        self.child = Some(Box::new(
239                            Self::new(scope_failure.clone(), self.symbol)
240                                .with_debugger(self.debugger.clone()),
241                        ));
242                    }
243                    self.position += 1;
244                    VmScopeResult::Continue
245                }
246                ScriptOperation::LoopScope { scope } => {
247                    if !context.stack().pop::<bool>().unwrap() {
248                        self.position += 1;
249                    } else {
250                        self.child = Some(Box::new(
251                            Self::new(scope.clone(), self.symbol)
252                                .with_debugger(self.debugger.clone()),
253                        ));
254                    }
255                    VmScopeResult::Continue
256                }
257                ScriptOperation::PushScope { scope } => {
258                    context.store_registers();
259                    self.child = Some(Box::new(
260                        Self::new(scope.clone(), self.symbol).with_debugger(self.debugger.clone()),
261                    ));
262                    self.position += 1;
263                    VmScopeResult::Continue
264                }
265                ScriptOperation::PopScope => {
266                    context.restore_registers();
267                    self.position = self.handle.len();
268                    VmScopeResult::Completed
269                }
270                ScriptOperation::ContinueScopeConditionally => {
271                    if context.stack().pop::<bool>().unwrap() {
272                        self.position += 1;
273                        VmScopeResult::Continue
274                    } else {
275                        self.position = self.handle.len();
276                        VmScopeResult::Completed
277                    }
278                }
279                ScriptOperation::Suspend => {
280                    self.position += 1;
281                    VmScopeResult::Suspended
282                }
283            };
284            if let Some(debugger) = self.debugger.as_ref()
285                && let Ok(mut debugger) = debugger.try_write()
286            {
287                debugger.on_exit_operation(self, operation, position, context, registry);
288            }
289            result
290        } else {
291            VmScopeResult::Completed
292        };
293        if (!result.can_progress() || self.position >= self.handle.len())
294            && let Some(debugger) = self.debugger.as_ref()
295            && let Ok(mut debugger) = debugger.try_write()
296        {
297            debugger.on_exit_scope(self, context, registry);
298        }
299        result
300    }
301}
302
303impl<SE: ScriptExpression + 'static> ScriptFunctionGenerator<SE> for VmScope<'static, SE> {
304    type Input = Option<VmDebuggerHandle<SE>>;
305    type Output = VmScopeSymbol;
306
307    fn generate_function_body(
308        script: ScriptHandle<'static, SE>,
309        debugger: Self::Input,
310    ) -> Option<(FunctionBody, Self::Output)> {
311        let symbol = VmScopeSymbol::new();
312        Some((
313            FunctionBody::closure(move |context, registry| {
314                Self::new(script.clone(), symbol)
315                    .with_debugger(debugger.clone())
316                    .run(context, registry);
317            }),
318            symbol,
319        ))
320    }
321}
322
323impl<SE: ScriptExpression> Clone for VmScope<'_, SE> {
324    fn clone(&self) -> Self {
325        Self {
326            handle: self.handle.clone(),
327            symbol: self.symbol,
328            position: self.position,
329            child: self.child.as_ref().map(|child| Box::new((**child).clone())),
330            debugger: self.debugger.clone(),
331        }
332    }
333}
334
335pub enum VmScopeFutureContext {
336    Owned(Box<Context>),
337    RefMut(ManagedRefMut<Context>),
338    Lazy(ManagedLazy<Context>),
339}
340
341impl From<Box<Context>> for VmScopeFutureContext {
342    fn from(value: Box<Context>) -> Self {
343        Self::Owned(value)
344    }
345}
346
347impl From<Context> for VmScopeFutureContext {
348    fn from(value: Context) -> Self {
349        Self::Owned(Box::new(value))
350    }
351}
352
353impl From<ManagedRefMut<Context>> for VmScopeFutureContext {
354    fn from(value: ManagedRefMut<Context>) -> Self {
355        Self::RefMut(value)
356    }
357}
358
359impl From<ManagedLazy<Context>> for VmScopeFutureContext {
360    fn from(value: ManagedLazy<Context>) -> Self {
361        Self::Lazy(value)
362    }
363}
364
365pub struct VmScopeFuture<'a, SE: ScriptExpression> {
366    pub scope: VmScope<'a, SE>,
367    pub context: VmScopeFutureContext,
368    pub registry: RegistryHandle,
369    pub operations_per_poll: usize,
370}
371
372impl<'a, SE: ScriptExpression> VmScopeFuture<'a, SE> {
373    pub fn new(
374        scope: VmScope<'a, SE>,
375        context: impl Into<VmScopeFutureContext>,
376        registry: RegistryHandle,
377    ) -> Self {
378        Self {
379            scope,
380            context: context.into(),
381            registry,
382            operations_per_poll: usize::MAX,
383        }
384    }
385
386    pub fn operations_per_poll(mut self, value: usize) -> Self {
387        self.operations_per_poll = value;
388        self
389    }
390
391    fn step(&mut self) -> Option<VmScopeResult> {
392        match &mut self.context {
393            VmScopeFutureContext::Owned(context) => {
394                Some(self.scope.step(&mut *context, &self.registry))
395            }
396            VmScopeFutureContext::RefMut(context) => {
397                let mut context = context.write()?;
398                Some(self.scope.step(&mut context, &self.registry))
399            }
400            VmScopeFutureContext::Lazy(context) => {
401                let mut context = context.write()?;
402                Some(self.scope.step(&mut context, &self.registry))
403            }
404        }
405    }
406}
407
408impl<SE: ScriptExpression> Future for VmScopeFuture<'_, SE> {
409    type Output = ();
410
411    fn poll(
412        mut self: std::pin::Pin<&mut Self>,
413        _cx: &mut std::task::Context<'_>,
414    ) -> std::task::Poll<Self::Output> {
415        for _ in 0..self.operations_per_poll {
416            match self.step() {
417                None => return std::task::Poll::Pending,
418                Some(VmScopeResult::Completed) => return std::task::Poll::Ready(()),
419                Some(VmScopeResult::Suspended) => return std::task::Poll::Pending,
420                Some(VmScopeResult::Continue) => {}
421            }
422        }
423        std::task::Poll::Pending
424    }
425}
426
#[cfg(test)]
mod tests {
    use crate::scope::*;
    use intuicio_core::{
        Visibility,
        function::{Function, FunctionParameter, FunctionQuery, FunctionSignature},
        script::{ScriptBuilder, ScriptFunction, ScriptFunctionParameter, ScriptFunctionSignature},
        types::{TypeQuery, struct_type::NativeStructBuilder},
    };
    use intuicio_data::managed::Managed;

    /// Compile-time check that the scope, the future and the future context are all
    /// `Send + Sync`.
    #[test]
    fn test_async() {
        fn is_async<T: Send + Sync>() {}

        is_async::<VmScope<()>>();
        is_async::<VmScopeFuture<()>>();
        is_async::<VmScopeFutureContext>();
    }

    /// Calls a native `add` function and a VM-generated `add_script` function (which
    /// forwards to `add`) with the same inputs. Both must return the same sum and
    /// leave the stack and register positions at 0.
    #[test]
    fn test_vm_scope() {
        let i32_handle = NativeStructBuilder::new::<i32>()
            .build()
            .into_type()
            .into_handle();
        let mut registry = Registry::default().with_basic_types();
        // Native reference implementation: pops two i32 values and pushes their sum.
        registry.add_function(Function::new(
            FunctionSignature::new("add")
                .with_input(FunctionParameter::new("a", i32_handle.clone()))
                .with_input(FunctionParameter::new("b", i32_handle.clone()))
                .with_output(FunctionParameter::new("result", i32_handle.clone())),
            FunctionBody::closure(|context, _| {
                let a = context.stack().pop::<i32>().unwrap();
                let b = context.stack().pop::<i32>().unwrap();
                context.stack().push(a + b);
            }),
        ));
        registry.add_function(
            VmScope::<()>::generate_function(
                &ScriptFunction {
                    signature: ScriptFunctionSignature {
                        meta: None,
                        name: "add_script".to_owned(),
                        module_name: None,
                        type_query: None,
                        visibility: Visibility::Public,
                        inputs: vec![
                            ScriptFunctionParameter {
                                meta: None,
                                name: "a".to_owned(),
                                type_query: TypeQuery::of::<i32>(),
                            },
                            ScriptFunctionParameter {
                                meta: None,
                                name: "b".to_owned(),
                                type_query: TypeQuery::of::<i32>(),
                            },
                        ],
                        outputs: vec![ScriptFunctionParameter {
                            meta: None,
                            name: "result".to_owned(),
                            type_query: TypeQuery::of::<i32>(),
                        }],
                    },
                    // Round-trips the top stack value through register 0, then calls
                    // `add`. NOTE(review): the register is never dropped by the script;
                    // the final register-position check presumably relies on the call
                    // cleaning it up — confirm.
                    script: ScriptBuilder::<()>::default()
                        .define_register(TypeQuery::of::<i32>())
                        .pop_to_register(0)
                        .push_from_register(0)
                        .call_function(FunctionQuery {
                            name: Some("add".into()),
                            ..Default::default()
                        })
                        .build(),
                },
                &registry,
                None,
            )
            .unwrap()
            // Keep the generated function; the second element is the generator output.
            .0,
        );
        registry.add_type_handle(i32_handle);
        let mut context = Context::new(10240, 10240);
        let (result,) = registry
            .find_function(FunctionQuery {
                name: Some("add".into()),
                ..Default::default()
            })
            .unwrap()
            .call::<(i32,), _>(&mut context, &registry, (40, 2), true);
        assert_eq!(result, 42);
        assert_eq!(context.stack().position(), 0);
        assert_eq!(context.registers().position(), 0);
        let (result,) = registry
            .find_function(FunctionQuery {
                name: Some("add_script".into()),
                ..Default::default()
            })
            .unwrap()
            .call::<(i32,), _>(&mut context, &registry, (40, 2), true);
        assert_eq!(result, 42);
        assert_eq!(context.stack().position(), 0);
        assert_eq!(context.registers().position(), 0);
    }

    /// Drives a suspending script through `VmScopeFuture` over a lazily shared
    /// context. The first poll stops at `Suspend`. The test then swaps the pushed
    /// value, and the second poll runs the increment and completes.
    #[test]
    fn test_vm_scope_future() {
        // Minimal expression set: push a literal, or increment the i32 on top.
        enum Expression {
            Literal(i32),
            Increment,
        }

        impl ScriptExpression for Expression {
            fn evaluate(&self, context: &mut Context, _registry: &Registry) {
                match self {
                    Expression::Literal(value) => {
                        context.stack().push(*value);
                    }
                    Expression::Increment => {
                        let value = context.stack().pop::<i32>().unwrap();
                        context.stack().push(value + 1);
                    }
                }
            }
        }

        let mut context = Managed::new(Context::new(10240, 10240));
        let registry = RegistryHandle::default();

        let script = ScriptBuilder::<Expression>::default()
            .expression(Expression::Literal(42))
            .suspend()
            .expression(Expression::Increment)
            .build();
        let scope = VmScope::new(script, VmScopeSymbol::new());
        let mut future = VmScopeFuture::new(scope, context.lazy(), registry);
        let mut future = std::pin::Pin::new(&mut future);
        // Polled manually, so a no-op waker is sufficient.
        let mut cx = std::task::Context::from_waker(std::task::Waker::noop());
        assert_eq!(context.write().unwrap().stack().position(), 0);

        // First poll: pushes 42, then hits `Suspend`.
        assert_eq!(future.as_mut().poll(&mut cx), std::task::Poll::Pending);
        // Byte offset after one pushed i32; it differs with the
        // `typehash_debug_name` feature.
        assert_eq!(
            context.write().unwrap().stack().position(),
            if cfg!(feature = "typehash_debug_name") {
                28
            } else {
                12
            }
        );
        // Replace the suspended 42 with 1 while the script is paused.
        assert_eq!(context.write().unwrap().stack().pop::<i32>().unwrap(), 42);
        context.write().unwrap().stack().push(1);

        // Second poll: `Increment` turns 1 into 2 and the scope completes.
        assert_eq!(future.as_mut().poll(&mut cx), std::task::Poll::Ready(()));
        assert_eq!(context.write().unwrap().stack().pop::<i32>().unwrap(), 2);
    }
}