Skip to main content

harn_vm/
value.rs

1use std::collections::BTreeMap;
2use std::rc::Rc;
3use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
4use std::sync::Arc;
5
6use crate::chunk::CompiledFunction;
7use crate::mcp::VmMcpClientHandle;
8
9/// An async builtin function for the VM.
10pub type VmAsyncBuiltinFn = Rc<
11    dyn Fn(
12        Vec<VmValue>,
13    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<VmValue, VmError>>>>,
14>;
15
16/// A spawned async task handle.
17pub type VmTaskHandle = tokio::task::JoinHandle<Result<(VmValue, String), VmError>>;
18
19/// A channel handle for the VM (uses tokio mpsc).
20#[derive(Debug, Clone)]
21pub struct VmChannelHandle {
22    pub name: String,
23    pub sender: Arc<tokio::sync::mpsc::Sender<VmValue>>,
24    pub receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
25    pub closed: Arc<AtomicBool>,
26}
27
28/// An atomic integer handle for the VM.
29#[derive(Debug, Clone)]
30pub struct VmAtomicHandle {
31    pub value: Arc<AtomicI64>,
32}
33
34/// VM runtime value.
35#[derive(Debug, Clone)]
36pub enum VmValue {
37    Int(i64),
38    Float(f64),
39    String(Rc<str>),
40    Bool(bool),
41    Nil,
42    List(Rc<Vec<VmValue>>),
43    Dict(Rc<BTreeMap<String, VmValue>>),
44    Closure(Rc<VmClosure>),
45    Duration(u64),
46    EnumVariant {
47        enum_name: String,
48        variant: String,
49        fields: Vec<VmValue>,
50    },
51    StructInstance {
52        struct_name: String,
53        fields: BTreeMap<String, VmValue>,
54    },
55    TaskHandle(String),
56    Channel(VmChannelHandle),
57    Atomic(VmAtomicHandle),
58    McpClient(VmMcpClientHandle),
59}
60
61/// A compiled closure value.
62#[derive(Debug, Clone)]
63pub struct VmClosure {
64    pub func: CompiledFunction,
65    pub env: VmEnv,
66}
67
68/// VM environment for variable storage.
69#[derive(Debug, Clone)]
70pub struct VmEnv {
71    pub(crate) scopes: Vec<Scope>,
72}
73
74#[derive(Debug, Clone)]
75pub(crate) struct Scope {
76    pub(crate) vars: BTreeMap<String, (VmValue, bool)>, // (value, mutable)
77}
78
79impl Default for VmEnv {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl VmEnv {
86    pub fn new() -> Self {
87        Self {
88            scopes: vec![Scope {
89                vars: BTreeMap::new(),
90            }],
91        }
92    }
93
94    pub fn push_scope(&mut self) {
95        self.scopes.push(Scope {
96            vars: BTreeMap::new(),
97        });
98    }
99
100    pub fn get(&self, name: &str) -> Option<VmValue> {
101        for scope in self.scopes.iter().rev() {
102            if let Some((val, _)) = scope.vars.get(name) {
103                return Some(val.clone());
104            }
105        }
106        None
107    }
108
109    pub fn define(&mut self, name: &str, value: VmValue, mutable: bool) {
110        if let Some(scope) = self.scopes.last_mut() {
111            scope.vars.insert(name.to_string(), (value, mutable));
112        }
113    }
114
115    pub fn all_variables(&self) -> BTreeMap<String, VmValue> {
116        let mut vars = BTreeMap::new();
117        for scope in &self.scopes {
118            for (name, (value, _)) in &scope.vars {
119                vars.insert(name.clone(), value.clone());
120            }
121        }
122        vars
123    }
124
125    pub fn assign(&mut self, name: &str, value: VmValue) -> Result<(), VmError> {
126        for scope in self.scopes.iter_mut().rev() {
127            if let Some((_, mutable)) = scope.vars.get(name) {
128                if !mutable {
129                    return Err(VmError::ImmutableAssignment(name.to_string()));
130                }
131                scope.vars.insert(name.to_string(), (value, true));
132                return Ok(());
133            }
134        }
135        Err(VmError::UndefinedVariable(name.to_string()))
136    }
137}
138
/// Compute the Levenshtein edit distance between two strings.
///
/// The distance is the minimum number of single-character insertions,
/// deletions and substitutions that turn `a` into `b`. It is measured in
/// `char`s (Unicode scalar values), not bytes. The classic dynamic-programming
/// recurrence is evaluated keeping only two rows of `b.len() + 1` cells.
fn levenshtein(a: &str, b: &str) -> usize {
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();
    // `prev[j]` = distance between the first `i` chars of `a` and the first `j` chars of `b`.
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    let mut curr = vec![0; b.len() + 1];
    for (i, &ca) in a.iter().enumerate() {
        // Reducing i + 1 chars of `a` to the empty prefix of `b` takes i + 1 deletions.
        curr[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            let substitution = usize::from(ca != cb);
            curr[j + 1] = (prev[j + 1] + 1) // deletion
                .min(curr[j] + 1) // insertion
                .min(prev[j] + substitution); // match / substitution
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[b.len()]
}

/// Find the closest match to `name` among `candidates` using Levenshtein distance.
///
/// The number of edits allowed scales with the length of `name`, counted in
/// `char`s to agree with [`levenshtein`]: up to 1 for names of 0–2 chars, 2
/// for 3–5 chars, and 3 otherwise. Candidates equal to `name` and candidates
/// starting with `__` are skipped.
///
/// Among the candidates within range, the winner has the smallest distance,
/// then the length closest to `name`, then the alphabetically first name.
/// Returns `Some(suggestion)` if a candidate is within range, else `None`.
pub fn closest_match<'a>(name: &str, candidates: impl Iterator<Item = &'a str>) -> Option<String> {
    let name_len = name.chars().count();
    let max_dist = match name_len {
        0..=2 => 1,
        3..=5 => 2,
        _ => 3,
    };
    // Length difference from `name`, in chars (not bytes).
    let len_gap = |s: &str| (s.chars().count() as isize - name_len as isize).unsigned_abs();
    candidates
        .filter(|c| *c != name && !c.starts_with("__"))
        .map(|c| (c, levenshtein(name, c)))
        .filter(|&(_, d)| d <= max_dist)
        .min_by(|&(a, da), &(b, db)| {
            da.cmp(&db)
                .then_with(|| len_gap(a).cmp(&len_gap(b)))
                .then_with(|| a.cmp(b))
        })
        .map(|(c, _)| c.to_string())
}
182
183#[derive(Debug, Clone)]
184pub enum VmError {
185    StackUnderflow,
186    StackOverflow,
187    UndefinedVariable(String),
188    UndefinedBuiltin(String),
189    ImmutableAssignment(String),
190    TypeError(String),
191    Runtime(String),
192    DivisionByZero,
193    Thrown(VmValue),
194    Return(VmValue),
195    InvalidInstruction(u8),
196}
197
198impl std::fmt::Display for VmError {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        match self {
201            VmError::StackUnderflow => write!(f, "Stack underflow"),
202            VmError::StackOverflow => write!(f, "Stack overflow: too many nested calls"),
203            VmError::UndefinedVariable(n) => write!(f, "Undefined variable: {n}"),
204            VmError::UndefinedBuiltin(n) => write!(f, "Undefined builtin: {n}"),
205            VmError::ImmutableAssignment(n) => {
206                write!(f, "Cannot assign to immutable binding: {n}")
207            }
208            VmError::TypeError(msg) => write!(f, "Type error: {msg}"),
209            VmError::Runtime(msg) => write!(f, "Runtime error: {msg}"),
210            VmError::DivisionByZero => write!(f, "Division by zero"),
211            VmError::Thrown(v) => write!(f, "Thrown: {}", v.display()),
212            VmError::Return(_) => write!(f, "Return from function"),
213            VmError::InvalidInstruction(op) => write!(f, "Invalid instruction: 0x{op:02x}"),
214        }
215    }
216}
217
218impl std::error::Error for VmError {}
219
220impl VmValue {
221    pub fn is_truthy(&self) -> bool {
222        match self {
223            VmValue::Bool(b) => *b,
224            VmValue::Nil => false,
225            VmValue::Int(n) => *n != 0,
226            VmValue::Float(n) => *n != 0.0,
227            VmValue::String(s) => !s.is_empty(),
228            VmValue::List(l) => !l.is_empty(),
229            VmValue::Dict(d) => !d.is_empty(),
230            VmValue::Closure(_) => true,
231            VmValue::Duration(ms) => *ms > 0,
232            VmValue::EnumVariant { .. } => true,
233            VmValue::StructInstance { .. } => true,
234            VmValue::TaskHandle(_) => true,
235            VmValue::Channel(_) => true,
236            VmValue::Atomic(_) => true,
237            VmValue::McpClient(_) => true,
238        }
239    }
240
241    pub fn type_name(&self) -> &'static str {
242        match self {
243            VmValue::String(_) => "string",
244            VmValue::Int(_) => "int",
245            VmValue::Float(_) => "float",
246            VmValue::Bool(_) => "bool",
247            VmValue::Nil => "nil",
248            VmValue::List(_) => "list",
249            VmValue::Dict(_) => "dict",
250            VmValue::Closure(_) => "closure",
251            VmValue::Duration(_) => "duration",
252            VmValue::EnumVariant { .. } => "enum",
253            VmValue::StructInstance { .. } => "struct",
254            VmValue::TaskHandle(_) => "task_handle",
255            VmValue::Channel(_) => "channel",
256            VmValue::Atomic(_) => "atomic",
257            VmValue::McpClient(_) => "mcp_client",
258        }
259    }
260
261    pub fn display(&self) -> String {
262        match self {
263            VmValue::Int(n) => n.to_string(),
264            VmValue::Float(n) => {
265                if *n == (*n as i64) as f64 && n.abs() < 1e15 {
266                    format!("{:.1}", n)
267                } else {
268                    n.to_string()
269                }
270            }
271            VmValue::String(s) => s.to_string(),
272            VmValue::Bool(b) => (if *b { "true" } else { "false" }).to_string(),
273            VmValue::Nil => "nil".to_string(),
274            VmValue::List(items) => {
275                let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
276                format!("[{}]", inner.join(", "))
277            }
278            VmValue::Dict(map) => {
279                let inner: Vec<String> = map
280                    .iter()
281                    .map(|(k, v)| format!("{k}: {}", v.display()))
282                    .collect();
283                format!("{{{}}}", inner.join(", "))
284            }
285            VmValue::Closure(c) => format!("<fn({})>", c.func.params.join(", ")),
286            VmValue::Duration(ms) => {
287                if *ms >= 3_600_000 && ms % 3_600_000 == 0 {
288                    format!("{}h", ms / 3_600_000)
289                } else if *ms >= 60_000 && ms % 60_000 == 0 {
290                    format!("{}m", ms / 60_000)
291                } else if *ms >= 1000 && ms % 1000 == 0 {
292                    format!("{}s", ms / 1000)
293                } else {
294                    format!("{}ms", ms)
295                }
296            }
297            VmValue::EnumVariant {
298                enum_name,
299                variant,
300                fields,
301            } => {
302                if fields.is_empty() {
303                    format!("{enum_name}.{variant}")
304                } else {
305                    let inner: Vec<String> = fields.iter().map(|v| v.display()).collect();
306                    format!("{enum_name}.{variant}({})", inner.join(", "))
307                }
308            }
309            VmValue::StructInstance {
310                struct_name,
311                fields,
312            } => {
313                let inner: Vec<String> = fields
314                    .iter()
315                    .map(|(k, v)| format!("{k}: {}", v.display()))
316                    .collect();
317                format!("{struct_name} {{{}}}", inner.join(", "))
318            }
319            VmValue::TaskHandle(id) => format!("<task:{id}>"),
320            VmValue::Channel(ch) => format!("<channel:{}>", ch.name),
321            VmValue::Atomic(a) => format!("<atomic:{}>", a.value.load(Ordering::SeqCst)),
322            VmValue::McpClient(c) => format!("<mcp_client:{}>", c.name),
323        }
324    }
325
326    /// Get the value as a BTreeMap reference, if it's a Dict.
327    pub fn as_dict(&self) -> Option<&BTreeMap<String, VmValue>> {
328        if let VmValue::Dict(d) = self {
329            Some(d)
330        } else {
331            None
332        }
333    }
334
335    pub fn as_int(&self) -> Option<i64> {
336        if let VmValue::Int(n) = self {
337            Some(*n)
338        } else {
339            None
340        }
341    }
342}
343
344/// Sync builtin function for the VM.
345pub type VmBuiltinFn = Rc<dyn Fn(&[VmValue], &mut String) -> Result<VmValue, VmError>>;
346
347pub fn values_equal(a: &VmValue, b: &VmValue) -> bool {
348    match (a, b) {
349        (VmValue::Int(x), VmValue::Int(y)) => x == y,
350        (VmValue::Float(x), VmValue::Float(y)) => x == y,
351        (VmValue::String(x), VmValue::String(y)) => x == y,
352        (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
353        (VmValue::Nil, VmValue::Nil) => true,
354        (VmValue::Int(x), VmValue::Float(y)) => (*x as f64) == *y,
355        (VmValue::Float(x), VmValue::Int(y)) => *x == (*y as f64),
356        (VmValue::TaskHandle(a), VmValue::TaskHandle(b)) => a == b,
357        (VmValue::Channel(_), VmValue::Channel(_)) => false, // channels are never equal
358        (VmValue::Atomic(a), VmValue::Atomic(b)) => {
359            a.value.load(Ordering::SeqCst) == b.value.load(Ordering::SeqCst)
360        }
361        (VmValue::List(a), VmValue::List(b)) => {
362            a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y))
363        }
364        (VmValue::Dict(a), VmValue::Dict(b)) => {
365            a.len() == b.len()
366                && a.iter()
367                    .zip(b.iter())
368                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
369        }
370        (
371            VmValue::EnumVariant {
372                enum_name: a_e,
373                variant: a_v,
374                fields: a_f,
375            },
376            VmValue::EnumVariant {
377                enum_name: b_e,
378                variant: b_v,
379                fields: b_f,
380            },
381        ) => {
382            a_e == b_e
383                && a_v == b_v
384                && a_f.len() == b_f.len()
385                && a_f.iter().zip(b_f.iter()).all(|(x, y)| values_equal(x, y))
386        }
387        (
388            VmValue::StructInstance {
389                struct_name: a_s,
390                fields: a_f,
391            },
392            VmValue::StructInstance {
393                struct_name: b_s,
394                fields: b_f,
395            },
396        ) => {
397            a_s == b_s
398                && a_f.len() == b_f.len()
399                && a_f
400                    .iter()
401                    .zip(b_f.iter())
402                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
403        }
404        _ => false,
405    }
406}
407
408pub fn compare_values(a: &VmValue, b: &VmValue) -> i32 {
409    match (a, b) {
410        (VmValue::Int(x), VmValue::Int(y)) => x.cmp(y) as i32,
411        (VmValue::Float(x), VmValue::Float(y)) => {
412            if x < y {
413                -1
414            } else if x > y {
415                1
416            } else {
417                0
418            }
419        }
420        (VmValue::Int(x), VmValue::Float(y)) => {
421            let x = *x as f64;
422            if x < *y {
423                -1
424            } else if x > *y {
425                1
426            } else {
427                0
428            }
429        }
430        (VmValue::Float(x), VmValue::Int(y)) => {
431            let y = *y as f64;
432            if *x < y {
433                -1
434            } else if *x > y {
435                1
436            } else {
437                0
438            }
439        }
440        (VmValue::String(x), VmValue::String(y)) => x.cmp(y) as i32,
441        _ => 0,
442    }
443}