Skip to main content

stak_vm/
vm.rs

1#[cfg(feature = "profile")]
2use crate::profiler::Profiler;
3use crate::{
4    Error, Exception, StackSlot,
5    code::{INTEGER_BASE, NUMBER_BASE, SHARE_BASE, TAG_BASE},
6    cons::{Cons, NEVER},
7    heap::Heap,
8    instruction::Instruction,
9    memory::Memory,
10    number::Number,
11    primitive_set::PrimitiveSet,
12    r#type::Type,
13    value::{TypedValue, Value},
14};
15#[cfg(feature = "profile")]
16use core::cell::RefCell;
17use core::{
18    fmt::{self, Display, Formatter, Write},
19    marker::PhantomData,
20};
21use stak_lzss::{Lzss, MAX_WINDOW_SIZE};
22use stak_util::block_on;
23use winter_maybe_async::{maybe_async, maybe_await};
24
// Prints a labelled value to standard error as `label: value`.
//
// The whole statement is compiled out unless the `trace_instruction` feature
// is enabled.
macro_rules! trace {
    ($label:literal, $value:expr) => {
        #[cfg(feature = "trace_instruction")]
        std::eprintln!("{}: {}", $label, $value);
    };
}
31
// Prints the given value with its `Display` implementation to standard error.
//
// The whole statement is compiled out unless the `trace_memory` feature is
// enabled.
macro_rules! trace_memory {
    ($vm:expr) => {
        #[cfg(feature = "trace_memory")]
        std::eprintln!("{}", $vm);
    };
}
38
// Reports a named event through `profile_event` on the given receiver and
// propagates its error with `?`.
//
// The whole statement is compiled out unless the `profile` feature is enabled.
macro_rules! profile_event {
    ($vm:expr, $event:literal) => {
        #[cfg(feature = "profile")]
        (&$vm).profile_event($event)?;
    };
}
45
/// An arity of a call site or a procedure.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct Arity {
    /// The number of fixed arguments. A variadic argument is not counted.
    count: usize,
    /// Whether a variadic argument is present.
    variadic: bool,
}
52
53/// A virtual machine.
54pub struct Vm<'a, T: PrimitiveSet<H>, H = &'a mut [Value]> {
55    primitive_set: T,
56    memory: Memory<H>,
57    #[cfg(feature = "profile")]
58    profiler: Option<RefCell<&'a mut dyn Profiler<H>>>,
59    _profiler: PhantomData<&'a ()>,
60}
61
62// Note that some routines look unnecessarily complicated as we need to mark all
63// volatile variables live across garbage collections.
64impl<'a, T: PrimitiveSet<H>, H: Heap> Vm<'a, T, H> {
65    /// Creates a virtual machine.
66    pub fn new(heap: H, primitive_set: T) -> Result<Self, Error> {
67        Ok(Self {
68            primitive_set,
69            memory: Memory::new(heap)?,
70            #[cfg(feature = "profile")]
71            profiler: None,
72            _profiler: Default::default(),
73        })
74    }
75
76    /// Sets a profiler.
77    #[cfg(feature = "profile")]
78    pub fn with_profiler(self, profiler: &'a mut dyn Profiler<H>) -> Self {
79        Self {
80            profiler: Some(profiler.into()),
81            ..self
82        }
83    }
84
85    /// Returns a reference to a primitive set.
86    pub const fn primitive_set(&self) -> &T {
87        &self.primitive_set
88    }
89
90    /// Returns a mutable reference to a primitive set.
91    pub const fn primitive_set_mut(&mut self) -> &mut T {
92        &mut self.primitive_set
93    }
94
95    /// Runs bytecode on a virtual machine synchronously.
96    ///
97    /// # Panics
98    ///
99    /// Panics if asynchronous operations occur during the run.
100    pub fn run(&mut self, input: impl IntoIterator<Item = u8>) -> Result<(), T::Error> {
101        block_on!(self.run_async(input))
102    }
103
104    /// Runs bytecode on a virtual machine.
105    #[cfg_attr(not(feature = "async"), doc(hidden))]
106    #[maybe_async]
107    pub fn run_async(&mut self, input: impl IntoIterator<Item = u8>) -> Result<(), T::Error> {
108        self.initialize(input)?;
109
110        while let Err(error) = maybe_await!(self.run_with_continuation()) {
111            if error.is_critical() {
112                return Err(error);
113            }
114
115            let Some(continuation) = self.memory.cdr(self.memory.null()?)?.to_cons() else {
116                return Err(error);
117            };
118
119            if self.memory.cdr(continuation)?.tag() != Type::Procedure as _ {
120                return Err(error);
121            }
122
123            self.memory.set_register(continuation);
124            let string = self.memory.build_string("")?;
125            let symbol = self.memory.allocate(
126                self.memory.register().into(),
127                string.set_tag(Type::Symbol as _).into(),
128            )?;
129            let code = self.memory.allocate(
130                symbol.into(),
131                self.memory
132                    .code()
133                    .set_tag(
134                        Instruction::Call as u16
135                            + Self::build_arity(Arity {
136                                count: 1,
137                                variadic: false,
138                            }) as u16,
139                    )
140                    .into(),
141            )?;
142            self.memory.set_code(code);
143
144            self.memory.set_register(self.memory.null()?);
145            write!(&mut self.memory, "{error}").map_err(Error::from)?;
146            let code = self.memory.allocate(
147                self.memory.register().into(),
148                self.memory
149                    .code()
150                    .set_tag(Instruction::Constant as _)
151                    .into(),
152            )?;
153            self.memory.set_code(code);
154        }
155
156        Ok(())
157    }
158
159    #[maybe_async]
160    fn run_with_continuation(&mut self) -> Result<(), T::Error> {
161        while self.memory.code() != self.memory.null()? {
162            let instruction = self.memory.cdr(self.memory.code())?.assume_cons();
163
164            trace!("instruction", instruction.tag());
165
166            match instruction.tag() {
167                Instruction::CONSTANT => self.constant()?,
168                Instruction::GET => self.get()?,
169                Instruction::SET => self.set()?,
170                Instruction::IF => self.r#if()?,
171                code => maybe_await!(
172                    self.call(instruction, code as usize - Instruction::CALL as usize)
173                )?,
174            }
175
176            self.advance_code()?;
177
178            trace_memory!(self);
179        }
180
181        Ok(())
182    }
183
184    fn constant(&mut self) -> Result<(), Error> {
185        let constant = self.operand()?;
186
187        trace!("constant", constant);
188
189        self.memory.push(constant)?;
190
191        Ok(())
192    }
193
194    fn get(&mut self) -> Result<(), Error> {
195        let operand = self.operand_cons()?;
196        let value = self.memory.car(operand)?;
197
198        trace!("operand", operand);
199        trace!("value", value);
200
201        self.memory.push(value)?;
202
203        Ok(())
204    }
205
206    fn set(&mut self) -> Result<(), Error> {
207        let operand = self.operand_cons()?;
208        let value = self.memory.pop()?;
209
210        trace!("operand", operand);
211        trace!("value", value);
212
213        self.memory.set_car(operand, value)?;
214
215        Ok(())
216    }
217
218    fn r#if(&mut self) -> Result<(), Error> {
219        let cons = self.memory.stack();
220
221        if self.memory.pop()? != self.memory.boolean(false)?.into() {
222            self.memory.set_cdr(cons, self.operand()?)?;
223            self.memory.set_code(cons);
224        }
225
226        Ok(())
227    }
228
229    #[maybe_async]
230    fn call(&mut self, instruction: Cons, arity: usize) -> Result<(), T::Error> {
231        let procedure = self.procedure()?;
232
233        trace!("procedure", procedure);
234
235        if self.environment(procedure)?.tag() != Type::Procedure as _ {
236            return Err(Error::ProcedureExpected.into());
237        }
238
239        let arguments = Self::parse_arity(arity);
240        let r#return = instruction == self.memory.null()?;
241
242        trace!("return", r#return);
243
244        match self.code(procedure)?.to_typed() {
245            TypedValue::Cons(code) => {
246                #[cfg(feature = "profile")]
247                self.profile_call(self.memory.code(), r#return)?;
248
249                let parameters =
250                    Self::parse_arity(self.memory.car(code)?.assume_number().to_i64() as usize);
251
252                trace!("argument count", arguments.count);
253                trace!("argument variadic", arguments.variadic);
254                trace!("parameter count", parameters.count);
255                trace!("parameter variadic", parameters.variadic);
256
257                self.memory.set_register(procedure);
258
259                let mut list = if arguments.variadic {
260                    self.memory.pop()?.assume_cons()
261                } else {
262                    self.memory.null()?
263                };
264
265                for _ in 0..arguments.count {
266                    let value = self.memory.pop()?;
267                    list = self.memory.cons(value, list)?;
268                }
269
270                // Use a `code` field as an escape cell for a procedure.
271                let code = self.memory.code();
272                self.memory.set_code(self.memory.register());
273                self.memory.set_register(list);
274
275                let continuation = if r#return {
276                    self.continuation()?
277                } else {
278                    self.memory
279                        .allocate(code.into(), self.memory.stack().into())?
280                };
281                let stack = self.memory.allocate(
282                    continuation.into(),
283                    self.environment(self.memory.code())?
284                        .set_tag(StackSlot::Frame as _)
285                        .into(),
286                )?;
287                self.memory.set_stack(stack);
288                self.memory
289                    .set_code(self.code(self.memory.code())?.assume_cons());
290
291                for _ in 0..parameters.count {
292                    if self.memory.register() == self.memory.null()? {
293                        return Err(Error::ArgumentCount.into());
294                    }
295
296                    self.memory.push(self.memory.car(self.memory.register())?)?;
297                    self.memory
298                        .set_register(self.memory.cdr(self.memory.register())?.assume_cons());
299                }
300
301                if parameters.variadic {
302                    self.memory.push(self.memory.register().into())?;
303                } else if self.memory.register() != self.memory.null()? {
304                    return Err(Error::ArgumentCount.into());
305                }
306            }
307            TypedValue::Number(primitive) => {
308                if arguments.variadic {
309                    let list = self.memory.pop()?.assume_cons();
310                    self.memory.set_register(list);
311
312                    while self.memory.register() != self.memory.null()? {
313                        self.memory.push(self.memory.car(self.memory.register())?)?;
314                        self.memory
315                            .set_register(self.memory.cdr(self.memory.register())?.assume_cons());
316                    }
317                }
318
319                maybe_await!(
320                    self.primitive_set
321                        .operate(&mut self.memory, primitive.to_i64() as _)
322                )?;
323            }
324        }
325
326        Ok(())
327    }
328
329    const fn parse_arity(info: usize) -> Arity {
330        Arity {
331            count: info / 2,
332            variadic: info % 2 == 1,
333        }
334    }
335
336    const fn build_arity(arity: Arity) -> usize {
337        2 * arity.count + arity.variadic as usize
338    }
339
340    fn advance_code(&mut self) -> Result<(), Error> {
341        let mut code = self.memory.cdr(self.memory.code())?.assume_cons();
342
343        if code == self.memory.null()? {
344            #[cfg(feature = "profile")]
345            self.profile_return()?;
346
347            let continuation = self.continuation()?;
348            // Keep a value at the top of a stack.
349            self.memory
350                .set_cdr(self.memory.stack(), self.memory.cdr(continuation)?)?;
351
352            code = self
353                .memory
354                .cdr(self.memory.car(continuation)?.assume_cons())?
355                .assume_cons();
356        }
357
358        self.memory.set_code(code);
359
360        Ok(())
361    }
362
363    fn operand(&self) -> Result<Value, Error> {
364        self.memory.car(self.memory.code())
365    }
366
367    fn operand_cons(&self) -> Result<Cons, Error> {
368        Ok(match self.operand()?.to_typed() {
369            TypedValue::Cons(cons) => cons,
370            TypedValue::Number(index) => {
371                self.memory.tail(self.memory.stack(), index.to_i64() as _)?
372            }
373        })
374    }
375
376    // (code . environment)
377    fn procedure(&self) -> Result<Cons, Error> {
378        Ok(self.memory.car(self.operand_cons()?)?.assume_cons())
379    }
380
381    // (parameter-count . instruction-list) | primitive-id
382    fn code(&self, procedure: Cons) -> Result<Value, Error> {
383        self.memory.car(procedure)
384    }
385
386    fn environment(&self, procedure: Cons) -> Result<Cons, Error> {
387        Ok(self.memory.cdr(procedure)?.assume_cons())
388    }
389
390    // (code . stack)
391    fn continuation(&self) -> Result<Cons, Error> {
392        let mut stack = self.memory.stack();
393
394        while self.memory.cdr(stack)?.assume_cons().tag() != StackSlot::Frame as _ {
395            stack = self.memory.cdr(stack)?.assume_cons();
396        }
397
398        Ok(self.memory.car(stack)?.assume_cons())
399    }
400
401    // Profiling
402
403    #[cfg(feature = "profile")]
404    fn profile_call(&self, call_code: Cons, r#return: bool) -> Result<(), Error> {
405        if let Some(profiler) = &self.profiler {
406            profiler
407                .borrow_mut()
408                .profile_call(&self.memory, call_code, r#return)?;
409        }
410
411        Ok(())
412    }
413
414    #[cfg(feature = "profile")]
415    fn profile_return(&self) -> Result<(), Error> {
416        if let Some(profiler) = &self.profiler {
417            profiler.borrow_mut().profile_return(&self.memory)?;
418        }
419
420        Ok(())
421    }
422
423    #[cfg(feature = "profile")]
424    fn profile_event(&self, name: &str) -> Result<(), Error> {
425        if let Some(profiler) = &self.profiler {
426            profiler.borrow_mut().profile_event(name)?;
427        }
428
429        Ok(())
430    }
431
432    // This function is public only for benchmarking.
433    #[doc(hidden)]
434    pub fn initialize(&mut self, input: impl IntoIterator<Item = u8>) -> Result<(), super::Error> {
435        profile_event!(self, "initialization_start");
436        profile_event!(self, "decode_start");
437
438        let program = self.decode_ribs(input.into_iter())?;
439        self.memory
440            .set_false(self.memory.car(program)?.assume_cons())?;
441        self.memory
442            .set_code(self.memory.cdr(program)?.assume_cons());
443
444        profile_event!(self, "decode_end");
445
446        // Initialize an implicit top-level frame.
447        let codes = self
448            .memory
449            .cons(Number::default().into(), self.memory.null()?)?
450            .into();
451        let continuation = self.memory.cons(codes, self.memory.null()?)?.into();
452        let stack = self.memory.allocate(
453            continuation,
454            self.memory.null()?.set_tag(StackSlot::Frame as _).into(),
455        )?;
456        self.memory.set_stack(stack);
457        self.memory.set_register(NEVER);
458
459        profile_event!(self, "initialization_end");
460
461        Ok(())
462    }
463
464    fn decode_ribs(&mut self, input: impl Iterator<Item = u8>) -> Result<Cons, Error> {
465        let mut input = input.decompress::<{ MAX_WINDOW_SIZE }>();
466
467        while let Some(head) = input.next() {
468            if head & 0b1 == 0 {
469                let head = head >> 1;
470
471                if head == 0 {
472                    let value = self.memory.top()?;
473                    let cons = self.memory.cons(value, self.memory.code())?;
474                    self.memory.set_code(cons);
475                } else {
476                    let integer = Self::decode_integer_tail(&mut input, head - 1, SHARE_BASE)?;
477                    let index = integer >> 1;
478
479                    if index > 0 {
480                        let cons = self.memory.tail(self.memory.code(), index as usize - 1)?;
481                        let head = self.memory.cdr(cons)?.assume_cons();
482                        let tail = self.memory.cdr(head)?;
483                        self.memory.set_cdr(head, self.memory.code().into())?;
484                        self.memory.set_cdr(cons, tail)?;
485                        self.memory.set_code(head);
486                    }
487
488                    let value = self.memory.car(self.memory.code())?;
489
490                    if integer & 1 == 0 {
491                        self.memory
492                            .set_code(self.memory.cdr(self.memory.code())?.assume_cons());
493                    }
494
495                    self.memory.push(value)?;
496                }
497            } else if head & 0b10 == 0 {
498                let cons = self.memory.stack();
499                let cdr = self.memory.pop()?;
500                let car = self.memory.top()?;
501                let tag = Self::decode_integer_tail(&mut input, head >> 2, TAG_BASE)?;
502                self.memory.set_car(cons, car)?;
503                self.memory.set_raw_cdr(cons, cdr.set_tag(tag as _))?;
504                self.memory.set_top(cons.into())?;
505            } else {
506                self.memory.push(
507                    Self::decode_number(Self::decode_integer_tail(
508                        &mut input,
509                        head >> 2,
510                        NUMBER_BASE,
511                    )?)
512                    .into(),
513                )?;
514            }
515        }
516
517        self.memory.pop()?.to_cons().ok_or(Error::BytecodeEnd)
518    }
519
520    fn decode_number(integer: u128) -> Number {
521        if integer & 1 == 0 {
522            Number::from_i64((integer >> 1) as _)
523        } else if integer & 0b10 == 0 {
524            Number::from_i64(-((integer >> 2) as i64))
525        } else {
526            let integer = integer >> 2;
527            let mantissa =
528                if integer.is_multiple_of(2) { 1.0 } else { -1.0 } * (integer >> 12) as f64;
529            let exponent = ((integer >> 1) % (1 << 11)) as isize - 1023;
530
531            Number::from_f64(if exponent < 0 {
532                mantissa / (1u64 << exponent.abs()) as f64
533            } else {
534                mantissa * (1u64 << exponent) as f64
535            })
536        }
537    }
538
539    fn decode_integer_tail(
540        input: &mut impl Iterator<Item = u8>,
541        mut x: u8,
542        mut base: u128,
543    ) -> Result<u128, Error> {
544        let mut y = (x >> 1) as u128;
545
546        while x & 1 != 0 {
547            x = input.next().ok_or(Error::BytecodeEnd)?;
548            y += (x as u128 >> 1) * base;
549            base *= INTEGER_BASE;
550        }
551
552        Ok(y)
553    }
554}
555
556impl<T: PrimitiveSet<H>, H: Heap> Display for Vm<'_, T, H> {
557    fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
558        write!(formatter, "{}", &self.memory)
559    }
560}
561
#[cfg(test)]
mod tests {
    use super::*;

    // A primitive set whose operations always succeed and do nothing.
    struct FakePrimitiveSet {}

    impl<H: Heap> PrimitiveSet<H> for FakePrimitiveSet {
        type Error = Error;

        #[maybe_async]
        fn operate(
            &mut self,
            _memory: &mut Memory<H>,
            _primitive: usize,
        ) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    type VoidVm = Vm<'static, FakePrimitiveSet>;

    // Checks that building and then parsing an arity is an identity.
    #[test]
    fn arity() {
        for variadic in [false, true] {
            for count in 0..3 {
                let arity = Arity { count, variadic };

                assert_eq!(VoidVm::parse_arity(VoidVm::build_arity(arity)), arity);
            }
        }
    }
}
614}