harn-vm 0.8.16

Async bytecode virtual machine for the Harn programming language
Documentation
use std::collections::BTreeMap;
use std::rc::Rc;

use crate::value::{VmError, VmValue};

/// Reports whether a range is already exhausted before yielding anything.
///
/// An inclusive range (`start..=end`) is empty only when `start` lies past
/// `end`; an exclusive range (`start..end`) is additionally empty when the
/// two bounds are equal.
fn range_initial_done(start: i64, end: i64, inclusive: bool) -> bool {
    start > end || (!inclusive && start == end)
}

/// Yields the next integer of a range and advances the cursor.
///
/// Returns `None` once `done` has been set. Otherwise the current value of
/// `next` is returned; if that value is the last one of the range, `done` is
/// set and `next` is left untouched, otherwise `next` is incremented.
///
/// The "last value" test never computes a successor that could overflow:
/// the inclusive case compares the current value against `end` directly, and
/// the exclusive case treats a failed `checked_add` as the end of the range.
fn range_next(next: &mut i64, end: i64, inclusive: bool, done: &mut bool) -> Option<i64> {
    if *done {
        return None;
    }

    let current = *next;
    let is_last = if inclusive {
        current >= end
    } else {
        match current.checked_add(1) {
            Some(successor) => successor >= end,
            // No representable successor: nothing can follow `current`.
            None => true,
        }
    };

    if is_last {
        *done = true;
    } else {
        *next = current + 1;
    }
    Some(current)
}

impl super::super::Vm {
    /// Pops the value on top of the stack and pushes the matching iterator
    /// state onto `self.iterators`.
    ///
    /// Lists and sets reuse their backing vector, dicts keep their entries
    /// plus a snapshot of the keys, and channels clone their receiver and
    /// closed flag. Generators, streams and iterator handles are moved in
    /// as-is, and ranges are seeded through `range_initial_done`. Any other
    /// value produces an empty vector iterator.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `self.pop()`.
    pub(super) fn execute_iter_init(&mut self) -> Result<(), VmError> {
        let iterable = self.pop()?;
        let state = match iterable {
            // Lists and sets share the same vector-backed iteration state.
            VmValue::List(items) | VmValue::Set(items) => {
                super::super::IterState::Vec { items, idx: 0 }
            }
            VmValue::Dict(entries) => {
                // Capture the key order once, up front.
                let keys = entries.keys().cloned().collect();
                super::super::IterState::Dict {
                    entries,
                    keys,
                    idx: 0,
                }
            }
            VmValue::Channel(channel) => super::super::IterState::Channel {
                receiver: channel.receiver.clone(),
                closed: channel.closed.clone(),
            },
            VmValue::Generator(gen) => super::super::IterState::Generator { gen },
            VmValue::Stream(stream) => super::super::IterState::Stream { stream },
            VmValue::Range(range) => super::super::IterState::Range {
                next: range.start,
                end: range.end,
                inclusive: range.inclusive,
                done: range_initial_done(range.start, range.end, range.inclusive),
            },
            VmValue::Iter(handle) => super::super::IterState::VmIter { handle },
            // Anything else iterates as an empty sequence.
            _ => super::super::IterState::Vec {
                items: Rc::new(Vec::new()),
                idx: 0,
            },
        };
        self.iterators.push(state);
        Ok(())
    }

    /// Advances the iterator on top of `self.iterators` by one step.
    ///
    /// A `u16` operand is read from the current frame's chunk at `frame.ip`
    /// (and `ip` is advanced by 2); it is used as the jump target once the
    /// iterator is exhausted. When the iterator yields a value, the value is
    /// pushed onto `self.stack`. When it is exhausted, the iterator state is
    /// popped and the frame's `ip` is set to the target. If there is no
    /// iterator state at all, the frame simply jumps to the target.
    ///
    /// # Errors
    ///
    /// Propagates errors from `crate::vm::iter::next_handle`, and returns any
    /// `Err` item received from a generator or stream. In the latter case the
    /// iterator state is popped and the source is marked done, but `ip` is
    /// NOT redirected to the target.
    ///
    /// # Panics
    ///
    /// Panics if `self.frames` is empty (the `unwrap()` calls on
    /// `last_mut()`).
    pub(super) async fn execute_iter_next(&mut self) -> Result<(), VmError> {
        let frame = self.frames.last_mut().unwrap();
        // Operand: where to continue once the iterator has no more items.
        let target = frame.chunk.read_u16(frame.ip) as usize;
        frame.ip += 2;
        // Clone the handle so we don't hold a borrow on self.iterators across
        // the async next() call.
        let vm_iter_handle = match self.iterators.last() {
            Some(super::super::IterState::VmIter { handle }) => Some(handle.clone()),
            _ => None,
        };
        if let Some(handle) = vm_iter_handle {
            // Safe for recursive VM reentry via closures as long as they don't
            // re-enter the same iter handle.
            let next_val = crate::vm::iter::next_handle(&handle, self).await?;
            match next_val {
                Some(v) => self.stack.push(v),
                None => {
                    self.iterators.pop();
                    let frame = self.frames.last_mut().unwrap();
                    frame.ip = target;
                }
            }
        } else if let Some(state) = self.iterators.last_mut() {
            match state {
                // Index-based walk over a shared vector (lists and sets).
                super::super::IterState::Vec { items, idx } => {
                    if *idx < items.len() {
                        let item = items[*idx].clone();
                        *idx += 1;
                        self.stack.push(item);
                    } else {
                        self.iterators.pop();
                        let frame = self.frames.last_mut().unwrap();
                        frame.ip = target;
                    }
                }
                // Walks the key snapshot; each step yields a fresh two-entry
                // dict with "key" and "value". A key that is missing from
                // `entries` yields `Nil` as its value.
                super::super::IterState::Dict { entries, keys, idx } => {
                    if *idx < keys.len() {
                        let key = &keys[*idx];
                        let value = entries.get(key).cloned().unwrap_or(VmValue::Nil);
                        *idx += 1;
                        self.stack.push(VmValue::Dict(Rc::new(BTreeMap::from([
                            ("key".to_string(), VmValue::String(Rc::from(key.as_str()))),
                            ("value".to_string(), value),
                        ]))));
                    } else {
                        self.iterators.pop();
                        let frame = self.frames.last_mut().unwrap();
                        frame.ip = target;
                    }
                }
                super::super::IterState::Channel { receiver, closed } => {
                    // Clone the receiver handle so the lock guard borrows the
                    // local clone rather than `self.iterators`.
                    let rx = receiver.clone();
                    // The closed flag is sampled once, before taking the lock.
                    let is_closed = closed.load(std::sync::atomic::Ordering::Relaxed);
                    let mut guard = rx.lock().await;
                    // Closed sender: drain without blocking.
                    // Any `try_recv` error is mapped to `None` and therefore
                    // ends the iteration.
                    let item = if is_closed {
                        guard.try_recv().ok()
                    } else {
                        guard.recv().await
                    };
                    match item {
                        Some(val) => {
                            self.stack.push(val);
                        }
                        None => {
                            drop(guard);
                            self.iterators.pop();
                            let frame = self.frames.last_mut().unwrap();
                            frame.ip = target;
                        }
                    }
                }
                // Integer ranges delegate the stepping logic to `range_next`.
                super::super::IterState::Range {
                    next,
                    end,
                    inclusive,
                    done,
                } => {
                    if let Some(v) = range_next(next, *end, *inclusive, done) {
                        self.stack.push(VmValue::Int(v));
                    } else {
                        self.iterators.pop();
                        let frame = self.frames.last_mut().unwrap();
                        frame.ip = target;
                    }
                }
                super::super::IterState::Generator { gen } => {
                    // A generator already marked done is never polled again.
                    if gen.done.get() {
                        self.iterators.pop();
                        let frame = self.frames.last_mut().unwrap();
                        frame.ip = target;
                    } else {
                        let rx = gen.receiver.clone();
                        let mut guard = rx.lock().await;
                        match guard.recv().await {
                            Some(Ok(val)) => {
                                self.stack.push(val);
                            }
                            // An error item ends the iteration and is
                            // returned to the caller; no jump to `target`.
                            Some(Err(error)) => {
                                gen.done.set(true);
                                drop(guard);
                                self.iterators.pop();
                                return Err(error);
                            }
                            // `recv()` returned `None`: treat as normal
                            // exhaustion.
                            None => {
                                gen.done.set(true);
                                drop(guard);
                                self.iterators.pop();
                                let frame = self.frames.last_mut().unwrap();
                                frame.ip = target;
                            }
                        }
                    }
                }
                // Streams follow the same protocol as generators above.
                super::super::IterState::Stream { stream } => {
                    if stream.done.get() {
                        self.iterators.pop();
                        let frame = self.frames.last_mut().unwrap();
                        frame.ip = target;
                    } else {
                        let rx = stream.receiver.clone();
                        let mut guard = rx.lock().await;
                        match guard.recv().await {
                            Some(Ok(val)) => {
                                self.stack.push(val);
                            }
                            Some(Err(error)) => {
                                stream.done.set(true);
                                drop(guard);
                                self.iterators.pop();
                                return Err(error);
                            }
                            None => {
                                stream.done.set(true);
                                drop(guard);
                                self.iterators.pop();
                                let frame = self.frames.last_mut().unwrap();
                                frame.ip = target;
                            }
                        }
                    }
                }
                super::super::IterState::VmIter { .. } => {
                    unreachable!("VmIter state handled before this match");
                }
            }
        } else {
            // No iterator state at all: behave as if the iterator were
            // already exhausted.
            let frame = self.frames.last_mut().unwrap();
            frame.ip = target;
        }
        Ok(())
    }

    /// Discards the iterator state on top of `self.iterators`.
    ///
    /// If the discarded state is a stream, `cancel()` is called on it. Every
    /// other kind of state (and an empty iterator stack) needs no further
    /// action.
    pub(super) fn execute_pop_iterator(&mut self) {
        let Some(super::super::IterState::Stream { stream }) = self.iterators.pop() else {
            return;
        };
        stream.cancel();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::vm::{IterState, Vm};

    /// Drives `test` to completion on a fresh current-thread Tokio runtime
    /// with all drivers enabled.
    fn run_iter_init_test(test: impl std::future::Future<Output = ()>) {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(test);
    }

    #[test]
    fn iter_init_list_keeps_shared_backing_store() {
        run_iter_init_test(async {
            let backing = Rc::new(vec![VmValue::Int(1), VmValue::Int(2)]);
            let mut vm = Vm::new();
            vm.stack.push(VmValue::List(Rc::clone(&backing)));

            vm.execute_iter_init().unwrap();

            // The iterator must alias the list's allocation, not copy it.
            let IterState::Vec {
                items: iter_items,
                idx,
            } = vm.iterators.last().unwrap()
            else {
                panic!("expected vec iterator state");
            };
            assert!(Rc::ptr_eq(&backing, iter_items));
            assert_eq!(*idx, 0);
        });
    }

    #[test]
    fn iter_init_set_keeps_shared_backing_store() {
        run_iter_init_test(async {
            let backing = Rc::new(vec![VmValue::Int(1), VmValue::Int(2)]);
            let mut vm = Vm::new();
            vm.stack.push(VmValue::Set(Rc::clone(&backing)));

            vm.execute_iter_init().unwrap();

            // Sets iterate through the same vector-backed state as lists and
            // must likewise share the original allocation.
            let IterState::Vec {
                items: iter_items,
                idx,
            } = vm.iterators.last().unwrap()
            else {
                panic!("expected vec iterator state");
            };
            assert!(Rc::ptr_eq(&backing, iter_items));
            assert_eq!(*idx, 0);
        });
    }

    #[test]
    fn iter_init_dict_keeps_shared_entries_and_snapshots_keys() {
        run_iter_init_test(async {
            let source = Rc::new(BTreeMap::from([
                ("a".to_string(), VmValue::Int(1)),
                ("b".to_string(), VmValue::Int(2)),
            ]));
            let mut vm = Vm::new();
            vm.stack.push(VmValue::Dict(Rc::clone(&source)));

            vm.execute_iter_init().unwrap();

            // Entries are shared with the original dict; keys are captured
            // separately in map order.
            let IterState::Dict {
                entries: iter_entries,
                keys,
                idx,
            } = vm.iterators.last().unwrap()
            else {
                panic!("expected dict iterator state");
            };
            assert!(Rc::ptr_eq(&source, iter_entries));
            assert_eq!(keys.as_slice(), ["a".to_string(), "b".to_string()]);
            assert_eq!(*idx, 0);
        });
    }

    #[test]
    fn iter_init_inclusive_range_at_i64_max_is_not_empty() {
        run_iter_init_test(async {
            let range = crate::value::VmRange {
                start: i64::MAX,
                end: i64::MAX,
                inclusive: true,
            };
            let mut vm = Vm::new();
            vm.stack.push(VmValue::Range(range));

            vm.execute_iter_init().unwrap();

            // `i64::MAX..=i64::MAX` holds exactly one element, so the
            // iterator must not start out in the done state.
            let IterState::Range {
                next,
                end,
                inclusive,
                done,
            } = vm.iterators.last().unwrap()
            else {
                panic!("expected range iterator state");
            };
            assert_eq!(*next, i64::MAX);
            assert_eq!(*end, i64::MAX);
            assert!(*inclusive);
            assert!(!*done);
        });
    }
}