Skip to main content

harn_vm/
value.rs

1use std::collections::BTreeMap;
2use std::rc::Rc;
3use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
4use std::sync::Arc;
5
6use crate::chunk::CompiledFunction;
7use crate::mcp::VmMcpClientHandle;
8
/// An async builtin function for the VM.
///
/// The callable takes its arguments by value as a `Vec<VmValue>` and returns a
/// pinned, boxed future that resolves to `Result<VmValue, VmError>`.
/// Neither the `Rc` nor the boxed future carries a `Send` bound.
pub type VmAsyncBuiltinFn = Rc<
    dyn Fn(
        Vec<VmValue>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<VmValue, VmError>>>>,
>;
15
/// The raw join handle type for spawned tasks.
///
/// A finished task yields either a `(VmValue, String)` pair or a `VmError`.
// NOTE(review): the `String` half is presumably the task's captured output — confirm at the spawn site.
pub type VmJoinHandle = tokio::task::JoinHandle<Result<(VmValue, String), VmError>>;
18
/// A spawned async task handle with cancellation support.
///
/// Pairs the tokio join handle with a shared flag used to request shutdown.
pub struct VmTaskHandle {
    /// Join handle for the spawned task.
    pub handle: VmJoinHandle,
    /// Cooperative cancellation token. Set to true to request graceful shutdown.
    pub cancel_token: Arc<AtomicBool>,
}
25
/// A channel handle for the VM (uses tokio mpsc).
///
/// All fields are reference-counted, so cloning the handle yields another
/// handle onto the same sender, receiver and `closed` flag.
#[derive(Debug, Clone)]
pub struct VmChannelHandle {
    /// Channel name; used when the value is rendered (`<channel:NAME>`).
    pub name: String,
    /// Sending half of the tokio mpsc channel.
    pub sender: Arc<tokio::sync::mpsc::Sender<VmValue>>,
    /// Receiving half, behind an async mutex so every clone shares one receiver.
    pub receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
    /// Shared boolean flag.
    // NOTE(review): presumably set once the channel has been closed — confirm against the channel builtins.
    pub closed: Arc<AtomicBool>,
}
34
/// An atomic integer handle for the VM.
///
/// Cloning the handle clones the `Arc`, so all clones observe the same counter.
#[derive(Debug, Clone)]
pub struct VmAtomicHandle {
    /// The shared 64-bit signed atomic value.
    pub value: Arc<AtomicI64>,
}
40
/// VM runtime value.
///
/// Strings, lists, dicts, sets and closures are stored behind `Rc`, so cloning
/// those variants copies a pointer rather than the payload.
#[derive(Debug, Clone)]
pub enum VmValue {
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit floating point number.
    Float(f64),
    /// Immutable, reference-counted string.
    String(Rc<str>),
    /// Boolean.
    Bool(bool),
    /// The absence of a value; rendered as `nil`.
    Nil,
    /// Ordered sequence of values.
    List(Rc<Vec<VmValue>>),
    /// String-keyed map; `BTreeMap` keeps the keys sorted.
    Dict(Rc<BTreeMap<String, VmValue>>),
    /// A compiled function together with its captured environment.
    Closure(Rc<VmClosure>),
    /// A duration; `display` treats the number as milliseconds.
    Duration(u64),
    /// A value of a user-declared enum: the enum's name, the variant's name
    /// and the variant's positional payload (empty for payload-less variants).
    EnumVariant {
        enum_name: String,
        variant: String,
        fields: Vec<VmValue>,
    },
    /// A value of a user-declared struct: the struct's name and its named fields.
    StructInstance {
        struct_name: String,
        fields: BTreeMap<String, VmValue>,
    },
    /// Identifier of a spawned task.
    // NOTE(review): presumably a key into a table of `VmTaskHandle`s — confirm in the VM.
    TaskHandle(String),
    /// Handle onto a tokio mpsc channel.
    Channel(VmChannelHandle),
    /// Handle onto a shared atomic integer.
    Atomic(VmAtomicHandle),
    /// Handle onto an MCP client.
    McpClient(VmMcpClientHandle),
    /// A set, stored as a vector of its elements.
    // NOTE(review): element uniqueness is not enforced by this type — presumably maintained by the set builtins.
    Set(Rc<Vec<VmValue>>),
}
68
/// A compiled closure value.
#[derive(Debug, Clone)]
pub struct VmClosure {
    /// The compiled function body and metadata (e.g. its parameter names).
    pub func: CompiledFunction,
    /// The environment stored alongside the function.
    // NOTE(review): presumably the scopes captured at closure creation — confirm in the VM.
    pub env: VmEnv,
}
75
/// VM environment for variable storage.
///
/// A stack of lexical scopes; lookups walk it from the last (innermost) scope
/// towards the first.
#[derive(Debug, Clone)]
pub struct VmEnv {
    /// Scope stack, outermost first. `VmEnv::new` starts it with one empty scope.
    pub(crate) scopes: Vec<Scope>,
}
81
/// A single lexical scope: the bindings introduced at one nesting level.
#[derive(Debug, Clone)]
pub(crate) struct Scope {
    /// Binding name mapped to its current value and a flag telling whether the
    /// binding may be reassigned.
    pub(crate) vars: BTreeMap<String, (VmValue, bool)>, // (value, mutable)
}
86
87impl Default for VmEnv {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl VmEnv {
94    pub fn new() -> Self {
95        Self {
96            scopes: vec![Scope {
97                vars: BTreeMap::new(),
98            }],
99        }
100    }
101
102    pub fn push_scope(&mut self) {
103        self.scopes.push(Scope {
104            vars: BTreeMap::new(),
105        });
106    }
107
108    pub fn get(&self, name: &str) -> Option<VmValue> {
109        for scope in self.scopes.iter().rev() {
110            if let Some((val, _)) = scope.vars.get(name) {
111                return Some(val.clone());
112            }
113        }
114        None
115    }
116
117    pub fn define(&mut self, name: &str, value: VmValue, mutable: bool) {
118        if let Some(scope) = self.scopes.last_mut() {
119            scope.vars.insert(name.to_string(), (value, mutable));
120        }
121    }
122
123    pub fn all_variables(&self) -> BTreeMap<String, VmValue> {
124        let mut vars = BTreeMap::new();
125        for scope in &self.scopes {
126            for (name, (value, _)) in &scope.vars {
127                vars.insert(name.clone(), value.clone());
128            }
129        }
130        vars
131    }
132
133    pub fn assign(&mut self, name: &str, value: VmValue) -> Result<(), VmError> {
134        for scope in self.scopes.iter_mut().rev() {
135            if let Some((_, mutable)) = scope.vars.get(name) {
136                if !mutable {
137                    return Err(VmError::ImmutableAssignment(name.to_string()));
138                }
139                scope.vars.insert(name.to_string(), (value, true));
140                return Ok(());
141            }
142        }
143        Err(VmError::UndefinedVariable(name.to_string()))
144    }
145}
146
/// Compute the Levenshtein edit distance between two strings.
///
/// The distance is the minimum number of single-character insertions,
/// deletions and substitutions that turn `a` into `b`. Strings are compared
/// per `char`, not per byte. Only two rows of the dynamic-programming table
/// are kept at any time.
fn levenshtein(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();
    let width = target.len();

    // `above` is the previous table row; row 0 is the cost of building each
    // prefix of `target` from the empty string.
    let mut above: Vec<usize> = (0..=width).collect();
    let mut row = vec![0usize; width + 1];

    for (i, src_ch) in source.iter().enumerate() {
        row[0] = i + 1;
        for (j, tgt_ch) in target.iter().enumerate() {
            let delete = above[j + 1] + 1;
            let insert = row[j] + 1;
            let substitute = above[j] + usize::from(src_ch != tgt_ch);
            row[j + 1] = delete.min(insert).min(substitute);
        }
        std::mem::swap(&mut above, &mut row);
    }
    above[width]
}
165
166/// Find the closest match from a list of candidates using Levenshtein distance.
167/// Returns `Some(suggestion)` if a candidate is within `max_dist` edits.
168pub fn closest_match<'a>(name: &str, candidates: impl Iterator<Item = &'a str>) -> Option<String> {
169    let max_dist = match name.len() {
170        0..=2 => 1,
171        3..=5 => 2,
172        _ => 3,
173    };
174    candidates
175        .filter(|c| *c != name && !c.starts_with("__"))
176        .map(|c| (c, levenshtein(name, c)))
177        .filter(|(_, d)| *d <= max_dist)
178        // Prefer smallest distance, then closest length to original, then alphabetical
179        .min_by(|(a, da), (b, db)| {
180            da.cmp(db)
181                .then_with(|| {
182                    let a_diff = (a.len() as isize - name.len() as isize).unsigned_abs();
183                    let b_diff = (b.len() as isize - name.len() as isize).unsigned_abs();
184                    a_diff.cmp(&b_diff)
185                })
186                .then_with(|| a.cmp(b))
187        })
188        .map(|(c, _)| c.to_string())
189}
190
/// VM runtime errors.
///
/// The `Display` impl gives the user-facing message for each variant.
#[derive(Debug, Clone)]
pub enum VmError {
    /// Reported as "Stack underflow".
    StackUnderflow,
    /// Reported as "Stack overflow: too many nested calls".
    StackOverflow,
    /// A variable name that no scope defines; carries the name.
    UndefinedVariable(String),
    /// A builtin name that is not known; carries the name.
    UndefinedBuiltin(String),
    /// Assignment to a binding that was defined as immutable; carries the name.
    ImmutableAssignment(String),
    /// A type error, with a free-form message.
    TypeError(String),
    /// A general runtime error, with a free-form message.
    Runtime(String),
    /// Reported as "Division by zero".
    DivisionByZero,
    /// A value thrown by the running program.
    Thrown(VmValue),
    /// Carries a function's return value.
    // NOTE(review): presumably used as control flow to unwind to the caller rather than as a real error — confirm in the VM loop.
    Return(VmValue),
    /// An opcode byte the VM does not recognise.
    InvalidInstruction(u8),
}
205
206impl std::fmt::Display for VmError {
207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        match self {
209            VmError::StackUnderflow => write!(f, "Stack underflow"),
210            VmError::StackOverflow => write!(f, "Stack overflow: too many nested calls"),
211            VmError::UndefinedVariable(n) => write!(f, "Undefined variable: {n}"),
212            VmError::UndefinedBuiltin(n) => write!(f, "Undefined builtin: {n}"),
213            VmError::ImmutableAssignment(n) => {
214                write!(f, "Cannot assign to immutable binding: {n}")
215            }
216            VmError::TypeError(msg) => write!(f, "Type error: {msg}"),
217            VmError::Runtime(msg) => write!(f, "Runtime error: {msg}"),
218            VmError::DivisionByZero => write!(f, "Division by zero"),
219            VmError::Thrown(v) => write!(f, "Thrown: {}", v.display()),
220            VmError::Return(_) => write!(f, "Return from function"),
221            VmError::InvalidInstruction(op) => write!(f, "Invalid instruction: 0x{op:02x}"),
222        }
223    }
224}
225
/// Lets `VmError` be used wherever a `std::error::Error` is expected; every
/// trait method keeps its default implementation.
impl std::error::Error for VmError {}
227
228impl VmValue {
229    pub fn is_truthy(&self) -> bool {
230        match self {
231            VmValue::Bool(b) => *b,
232            VmValue::Nil => false,
233            VmValue::Int(n) => *n != 0,
234            VmValue::Float(n) => *n != 0.0,
235            VmValue::String(s) => !s.is_empty(),
236            VmValue::List(l) => !l.is_empty(),
237            VmValue::Dict(d) => !d.is_empty(),
238            VmValue::Closure(_) => true,
239            VmValue::Duration(ms) => *ms > 0,
240            VmValue::EnumVariant { .. } => true,
241            VmValue::StructInstance { .. } => true,
242            VmValue::TaskHandle(_) => true,
243            VmValue::Channel(_) => true,
244            VmValue::Atomic(_) => true,
245            VmValue::McpClient(_) => true,
246            VmValue::Set(s) => !s.is_empty(),
247        }
248    }
249
250    pub fn type_name(&self) -> &'static str {
251        match self {
252            VmValue::String(_) => "string",
253            VmValue::Int(_) => "int",
254            VmValue::Float(_) => "float",
255            VmValue::Bool(_) => "bool",
256            VmValue::Nil => "nil",
257            VmValue::List(_) => "list",
258            VmValue::Dict(_) => "dict",
259            VmValue::Closure(_) => "closure",
260            VmValue::Duration(_) => "duration",
261            VmValue::EnumVariant { .. } => "enum",
262            VmValue::StructInstance { .. } => "struct",
263            VmValue::TaskHandle(_) => "task_handle",
264            VmValue::Channel(_) => "channel",
265            VmValue::Atomic(_) => "atomic",
266            VmValue::McpClient(_) => "mcp_client",
267            VmValue::Set(_) => "set",
268        }
269    }
270
271    pub fn display(&self) -> String {
272        match self {
273            VmValue::Int(n) => n.to_string(),
274            VmValue::Float(n) => {
275                if *n == (*n as i64) as f64 && n.abs() < 1e15 {
276                    format!("{:.1}", n)
277                } else {
278                    n.to_string()
279                }
280            }
281            VmValue::String(s) => s.to_string(),
282            VmValue::Bool(b) => (if *b { "true" } else { "false" }).to_string(),
283            VmValue::Nil => "nil".to_string(),
284            VmValue::List(items) => {
285                let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
286                format!("[{}]", inner.join(", "))
287            }
288            VmValue::Dict(map) => {
289                let inner: Vec<String> = map
290                    .iter()
291                    .map(|(k, v)| format!("{k}: {}", v.display()))
292                    .collect();
293                format!("{{{}}}", inner.join(", "))
294            }
295            VmValue::Closure(c) => format!("<fn({})>", c.func.params.join(", ")),
296            VmValue::Duration(ms) => {
297                if *ms >= 3_600_000 && ms % 3_600_000 == 0 {
298                    format!("{}h", ms / 3_600_000)
299                } else if *ms >= 60_000 && ms % 60_000 == 0 {
300                    format!("{}m", ms / 60_000)
301                } else if *ms >= 1000 && ms % 1000 == 0 {
302                    format!("{}s", ms / 1000)
303                } else {
304                    format!("{}ms", ms)
305                }
306            }
307            VmValue::EnumVariant {
308                enum_name,
309                variant,
310                fields,
311            } => {
312                if fields.is_empty() {
313                    format!("{enum_name}.{variant}")
314                } else {
315                    let inner: Vec<String> = fields.iter().map(|v| v.display()).collect();
316                    format!("{enum_name}.{variant}({})", inner.join(", "))
317                }
318            }
319            VmValue::StructInstance {
320                struct_name,
321                fields,
322            } => {
323                let inner: Vec<String> = fields
324                    .iter()
325                    .map(|(k, v)| format!("{k}: {}", v.display()))
326                    .collect();
327                format!("{struct_name} {{{}}}", inner.join(", "))
328            }
329            VmValue::TaskHandle(id) => format!("<task:{id}>"),
330            VmValue::Channel(ch) => format!("<channel:{}>", ch.name),
331            VmValue::Atomic(a) => format!("<atomic:{}>", a.value.load(Ordering::SeqCst)),
332            VmValue::McpClient(c) => format!("<mcp_client:{}>", c.name),
333            VmValue::Set(items) => {
334                let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
335                format!("set({})", inner.join(", "))
336            }
337        }
338    }
339
340    /// Get the value as a BTreeMap reference, if it's a Dict.
341    pub fn as_dict(&self) -> Option<&BTreeMap<String, VmValue>> {
342        if let VmValue::Dict(d) = self {
343            Some(d)
344        } else {
345            None
346        }
347    }
348
349    pub fn as_int(&self) -> Option<i64> {
350        if let VmValue::Int(n) = self {
351            Some(*n)
352        } else {
353            None
354        }
355    }
356}
357
/// Sync builtin function for the VM.
///
/// The callable receives its arguments as a slice plus a mutable `String`,
/// and returns the resulting value or a `VmError`.
// NOTE(review): the `&mut String` is presumably the output buffer builtins append to — confirm at the call site.
pub type VmBuiltinFn = Rc<dyn Fn(&[VmValue], &mut String) -> Result<VmValue, VmError>>;
360
361pub fn values_equal(a: &VmValue, b: &VmValue) -> bool {
362    match (a, b) {
363        (VmValue::Int(x), VmValue::Int(y)) => x == y,
364        (VmValue::Float(x), VmValue::Float(y)) => x == y,
365        (VmValue::String(x), VmValue::String(y)) => x == y,
366        (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
367        (VmValue::Nil, VmValue::Nil) => true,
368        (VmValue::Int(x), VmValue::Float(y)) => (*x as f64) == *y,
369        (VmValue::Float(x), VmValue::Int(y)) => *x == (*y as f64),
370        (VmValue::TaskHandle(a), VmValue::TaskHandle(b)) => a == b,
371        (VmValue::Channel(_), VmValue::Channel(_)) => false, // channels are never equal
372        (VmValue::Atomic(a), VmValue::Atomic(b)) => {
373            a.value.load(Ordering::SeqCst) == b.value.load(Ordering::SeqCst)
374        }
375        (VmValue::List(a), VmValue::List(b)) => {
376            a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y))
377        }
378        (VmValue::Dict(a), VmValue::Dict(b)) => {
379            a.len() == b.len()
380                && a.iter()
381                    .zip(b.iter())
382                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
383        }
384        (
385            VmValue::EnumVariant {
386                enum_name: a_e,
387                variant: a_v,
388                fields: a_f,
389            },
390            VmValue::EnumVariant {
391                enum_name: b_e,
392                variant: b_v,
393                fields: b_f,
394            },
395        ) => {
396            a_e == b_e
397                && a_v == b_v
398                && a_f.len() == b_f.len()
399                && a_f.iter().zip(b_f.iter()).all(|(x, y)| values_equal(x, y))
400        }
401        (
402            VmValue::StructInstance {
403                struct_name: a_s,
404                fields: a_f,
405            },
406            VmValue::StructInstance {
407                struct_name: b_s,
408                fields: b_f,
409            },
410        ) => {
411            a_s == b_s
412                && a_f.len() == b_f.len()
413                && a_f
414                    .iter()
415                    .zip(b_f.iter())
416                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
417        }
418        (VmValue::Set(a), VmValue::Set(b)) => {
419            a.len() == b.len() && a.iter().all(|x| b.iter().any(|y| values_equal(x, y)))
420        }
421        _ => false,
422    }
423}
424
425pub fn compare_values(a: &VmValue, b: &VmValue) -> i32 {
426    match (a, b) {
427        (VmValue::Int(x), VmValue::Int(y)) => x.cmp(y) as i32,
428        (VmValue::Float(x), VmValue::Float(y)) => {
429            if x < y {
430                -1
431            } else if x > y {
432                1
433            } else {
434                0
435            }
436        }
437        (VmValue::Int(x), VmValue::Float(y)) => {
438            let x = *x as f64;
439            if x < *y {
440                -1
441            } else if x > *y {
442                1
443            } else {
444                0
445            }
446        }
447        (VmValue::Float(x), VmValue::Int(y)) => {
448            let y = *y as f64;
449            if *x < y {
450                -1
451            } else if *x > y {
452                1
453            } else {
454                0
455            }
456        }
457        (VmValue::String(x), VmValue::String(y)) => x.cmp(y) as i32,
458        _ => 0,
459    }
460}