Skip to main content

harn_vm/
value.rs

1use std::collections::BTreeMap;
2use std::rc::Rc;
3use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
4use std::sync::Arc;
5
6use crate::chunk::CompiledFunction;
7use crate::mcp::VmMcpClientHandle;
8
9/// An async builtin function for the VM.
10pub type VmAsyncBuiltinFn = Rc<
11    dyn Fn(
12        Vec<VmValue>,
13    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<VmValue, VmError>>>>,
14>;
15
/// A spawned async task handle.
///
/// A tokio `JoinHandle` whose task produces either a `(VmValue, String)` pair
/// or a `VmError`. NOTE(review): the `String` half is presumably text
/// produced by the task (e.g. captured output); confirm at the spawn site.
pub type VmTaskHandle = tokio::task::JoinHandle<Result<(VmValue, String), VmError>>;
18
/// A channel handle for the VM (uses tokio mpsc).
///
/// All shared state sits behind `Arc`, so cloning a handle (derived `Clone`)
/// yields another handle to the same underlying channel.
#[derive(Debug, Clone)]
pub struct VmChannelHandle {
    /// Channel name; shown when the value is displayed as `<channel:NAME>`.
    pub name: String,
    /// Sending half of the tokio mpsc channel.
    pub sender: Arc<tokio::sync::mpsc::Sender<VmValue>>,
    /// Receiving half, shared between clones behind an async `tokio::sync::Mutex`.
    pub receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
    /// Shared "closed" flag.
    /// NOTE(review): set and read elsewhere; exact semantics not visible here.
    pub closed: Arc<AtomicBool>,
}
27
/// An atomic integer handle for the VM.
///
/// The counter lives behind an `Arc`, so clones share the same value.
/// `values_equal` compares two handles by their currently loaded values,
/// not by identity.
#[derive(Debug, Clone)]
pub struct VmAtomicHandle {
    /// Shared 64-bit signed counter.
    pub value: Arc<AtomicI64>,
}
33
/// VM runtime value.
///
/// Strings, lists, dicts, sets and closures keep their payload behind `Rc`,
/// so cloning those variants bumps a reference count instead of copying.
#[derive(Debug, Clone)]
pub enum VmValue {
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// Immutable, reference-counted string.
    String(Rc<str>),
    /// Boolean.
    Bool(bool),
    /// Absence of a value; always falsy.
    Nil,
    /// Ordered list of values.
    List(Rc<Vec<VmValue>>),
    /// String-keyed map; `BTreeMap` keeps keys sorted, so iteration and
    /// display order is by key.
    Dict(Rc<BTreeMap<String, VmValue>>),
    /// A compiled function paired with an environment.
    Closure(Rc<VmClosure>),
    /// Duration in milliseconds (displayed with h/m/s/ms units).
    Duration(u64),
    /// Enum variant; displayed as `Enum.Variant` or `Enum.Variant(a, b)`.
    EnumVariant {
        enum_name: String,
        variant: String,
        fields: Vec<VmValue>,
    },
    /// Struct instance with named fields (kept sorted by name).
    StructInstance {
        struct_name: String,
        fields: BTreeMap<String, VmValue>,
    },
    /// Identifier of a spawned task; displayed as `<task:ID>`.
    TaskHandle(String),
    /// Handle to a VM channel.
    Channel(VmChannelHandle),
    /// Handle to a shared atomic integer.
    Atomic(VmAtomicHandle),
    /// Client handle from `crate::mcp`; displayed as `<mcp_client:NAME>`.
    McpClient(VmMcpClientHandle),
    /// Set stored as a vector; `values_equal` compares sets without regard
    /// to order. NOTE(review): element uniqueness is presumably enforced
    /// where sets are built; confirm.
    Set(Rc<Vec<VmValue>>),
}
61
/// A compiled closure value.
///
/// Pairs a compiled function with a variable environment. The function's
/// `params` are listed when the closure is displayed (`<fn(a, b)>`).
#[derive(Debug, Clone)]
pub struct VmClosure {
    /// The compiled function.
    pub func: CompiledFunction,
    /// Environment associated with the closure.
    /// NOTE(review): presumably the scopes captured at creation time; confirm.
    pub env: VmEnv,
}
68
/// VM environment for variable storage.
///
/// A stack of lexical scopes. Lookups search from the innermost (last) scope
/// outward, and new bindings are added to the innermost scope.
#[derive(Debug, Clone)]
pub struct VmEnv {
    /// Scope stack; index 0 is the outermost scope. `VmEnv::new` starts with
    /// exactly one empty scope.
    pub(crate) scopes: Vec<Scope>,
}
74
/// A single lexical scope: maps each variable name to its value and a flag
/// recording whether the binding may be reassigned.
#[derive(Debug, Clone)]
pub(crate) struct Scope {
    pub(crate) vars: BTreeMap<String, (VmValue, bool)>, // (value, mutable)
}
79
80impl Default for VmEnv {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl VmEnv {
87    pub fn new() -> Self {
88        Self {
89            scopes: vec![Scope {
90                vars: BTreeMap::new(),
91            }],
92        }
93    }
94
95    pub fn push_scope(&mut self) {
96        self.scopes.push(Scope {
97            vars: BTreeMap::new(),
98        });
99    }
100
101    pub fn get(&self, name: &str) -> Option<VmValue> {
102        for scope in self.scopes.iter().rev() {
103            if let Some((val, _)) = scope.vars.get(name) {
104                return Some(val.clone());
105            }
106        }
107        None
108    }
109
110    pub fn define(&mut self, name: &str, value: VmValue, mutable: bool) {
111        if let Some(scope) = self.scopes.last_mut() {
112            scope.vars.insert(name.to_string(), (value, mutable));
113        }
114    }
115
116    pub fn all_variables(&self) -> BTreeMap<String, VmValue> {
117        let mut vars = BTreeMap::new();
118        for scope in &self.scopes {
119            for (name, (value, _)) in &scope.vars {
120                vars.insert(name.clone(), value.clone());
121            }
122        }
123        vars
124    }
125
126    pub fn assign(&mut self, name: &str, value: VmValue) -> Result<(), VmError> {
127        for scope in self.scopes.iter_mut().rev() {
128            if let Some((_, mutable)) = scope.vars.get(name) {
129                if !mutable {
130                    return Err(VmError::ImmutableAssignment(name.to_string()));
131                }
132                scope.vars.insert(name.to_string(), (value, true));
133                return Ok(());
134            }
135        }
136        Err(VmError::UndefinedVariable(name.to_string()))
137    }
138}
139
/// Compute Levenshtein edit distance between two strings.
///
/// Distance is counted in `char`s (Unicode scalar values), not bytes. It is
/// the minimum number of single-char insertions, deletions or substitutions
/// that turn `a` into `b`. Only two DP rows are kept, so extra memory is
/// proportional to the length of `b`.
fn levenshtein(a: &str, b: &str) -> usize {
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();
    // Before processing char `i` of `a`, `prev[j]` is the distance between
    // the first `i` chars of `a` and the first `j` chars of `b`.
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    let mut curr = vec![0; b.len() + 1];
    for (i, &ca) in a.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            let cost = usize::from(ca != cb);
            // deletion, insertion, substitution (or match)
            curr[j + 1] = (prev[j + 1] + 1).min(curr[j] + 1).min(prev[j] + cost);
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[b.len()]
}

/// Find the closest match from a list of candidates using Levenshtein distance.
///
/// Candidates equal to `name` or starting with `"__"` are ignored. The edit
/// budget scales with the length of `name` in chars: 1 edit for up to 2
/// chars, 2 for 3–5 chars, and 3 otherwise. Among candidates within budget,
/// the winner is chosen by:
/// 1. the smallest distance,
/// 2. then the length closest to `name`'s,
/// 3. then alphabetical order.
///
/// Returns `Some(suggestion)` if a candidate is within the budget, else `None`.
pub fn closest_match<'a>(name: &str, candidates: impl Iterator<Item = &'a str>) -> Option<String> {
    // Count lengths in chars so they agree with `levenshtein`. Byte lengths
    // would inflate the budget and skew the tie-break for non-ASCII names.
    let name_len = name.chars().count();
    let max_dist = match name_len {
        0..=2 => 1,
        3..=5 => 2,
        _ => 3,
    };
    let len_gap = |c: &str| (c.chars().count() as isize - name_len as isize).unsigned_abs();
    candidates
        .filter(|c| *c != name && !c.starts_with("__"))
        .map(|c| (c, levenshtein(name, c)))
        .filter(|&(_, d)| d <= max_dist)
        // Prefer smallest distance, then closest length to original, then alphabetical
        .min_by_key(|&(c, d)| (d, len_gap(c), c))
        .map(|(c, _)| c.to_string())
}
183
/// VM runtime errors.
///
/// Not every variant is a failure: `Thrown` and `Return` carry a `VmValue`
/// through the error channel. NOTE(review): presumably these propagate a
/// user-level throw and a function return respectively; confirm in the
/// interpreter loop.
#[derive(Debug, Clone)]
pub enum VmError {
    /// Displayed as "Stack underflow".
    StackUnderflow,
    /// Displayed as "Stack overflow: too many nested calls".
    StackOverflow,
    /// Carries the variable name; returned e.g. by `VmEnv::assign` when no
    /// scope binds it.
    UndefinedVariable(String),
    /// Carries the name of a builtin that was not found.
    UndefinedBuiltin(String),
    /// Carries the name of an immutable binding that `VmEnv::assign`
    /// refused to reassign.
    ImmutableAssignment(String),
    /// Type error with a descriptive message.
    TypeError(String),
    /// General runtime error with a free-form message.
    Runtime(String),
    /// Division by zero.
    DivisionByZero,
    /// A value raised as an error.
    Thrown(VmValue),
    /// A return value carried through the error path.
    Return(VmValue),
    /// Carries the offending instruction byte (displayed as `0x..` hex).
    InvalidInstruction(u8),
}
198
199impl std::fmt::Display for VmError {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        match self {
202            VmError::StackUnderflow => write!(f, "Stack underflow"),
203            VmError::StackOverflow => write!(f, "Stack overflow: too many nested calls"),
204            VmError::UndefinedVariable(n) => write!(f, "Undefined variable: {n}"),
205            VmError::UndefinedBuiltin(n) => write!(f, "Undefined builtin: {n}"),
206            VmError::ImmutableAssignment(n) => {
207                write!(f, "Cannot assign to immutable binding: {n}")
208            }
209            VmError::TypeError(msg) => write!(f, "Type error: {msg}"),
210            VmError::Runtime(msg) => write!(f, "Runtime error: {msg}"),
211            VmError::DivisionByZero => write!(f, "Division by zero"),
212            VmError::Thrown(v) => write!(f, "Thrown: {}", v.display()),
213            VmError::Return(_) => write!(f, "Return from function"),
214            VmError::InvalidInstruction(op) => write!(f, "Invalid instruction: 0x{op:02x}"),
215        }
216    }
217}
218
// Lets `VmError` be used as a standard error: messages come from its
// `Display` impl, and the default methods are kept (no `source()` chain).
impl std::error::Error for VmError {}
220
221impl VmValue {
222    pub fn is_truthy(&self) -> bool {
223        match self {
224            VmValue::Bool(b) => *b,
225            VmValue::Nil => false,
226            VmValue::Int(n) => *n != 0,
227            VmValue::Float(n) => *n != 0.0,
228            VmValue::String(s) => !s.is_empty(),
229            VmValue::List(l) => !l.is_empty(),
230            VmValue::Dict(d) => !d.is_empty(),
231            VmValue::Closure(_) => true,
232            VmValue::Duration(ms) => *ms > 0,
233            VmValue::EnumVariant { .. } => true,
234            VmValue::StructInstance { .. } => true,
235            VmValue::TaskHandle(_) => true,
236            VmValue::Channel(_) => true,
237            VmValue::Atomic(_) => true,
238            VmValue::McpClient(_) => true,
239            VmValue::Set(s) => !s.is_empty(),
240        }
241    }
242
243    pub fn type_name(&self) -> &'static str {
244        match self {
245            VmValue::String(_) => "string",
246            VmValue::Int(_) => "int",
247            VmValue::Float(_) => "float",
248            VmValue::Bool(_) => "bool",
249            VmValue::Nil => "nil",
250            VmValue::List(_) => "list",
251            VmValue::Dict(_) => "dict",
252            VmValue::Closure(_) => "closure",
253            VmValue::Duration(_) => "duration",
254            VmValue::EnumVariant { .. } => "enum",
255            VmValue::StructInstance { .. } => "struct",
256            VmValue::TaskHandle(_) => "task_handle",
257            VmValue::Channel(_) => "channel",
258            VmValue::Atomic(_) => "atomic",
259            VmValue::McpClient(_) => "mcp_client",
260            VmValue::Set(_) => "set",
261        }
262    }
263
264    pub fn display(&self) -> String {
265        match self {
266            VmValue::Int(n) => n.to_string(),
267            VmValue::Float(n) => {
268                if *n == (*n as i64) as f64 && n.abs() < 1e15 {
269                    format!("{:.1}", n)
270                } else {
271                    n.to_string()
272                }
273            }
274            VmValue::String(s) => s.to_string(),
275            VmValue::Bool(b) => (if *b { "true" } else { "false" }).to_string(),
276            VmValue::Nil => "nil".to_string(),
277            VmValue::List(items) => {
278                let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
279                format!("[{}]", inner.join(", "))
280            }
281            VmValue::Dict(map) => {
282                let inner: Vec<String> = map
283                    .iter()
284                    .map(|(k, v)| format!("{k}: {}", v.display()))
285                    .collect();
286                format!("{{{}}}", inner.join(", "))
287            }
288            VmValue::Closure(c) => format!("<fn({})>", c.func.params.join(", ")),
289            VmValue::Duration(ms) => {
290                if *ms >= 3_600_000 && ms % 3_600_000 == 0 {
291                    format!("{}h", ms / 3_600_000)
292                } else if *ms >= 60_000 && ms % 60_000 == 0 {
293                    format!("{}m", ms / 60_000)
294                } else if *ms >= 1000 && ms % 1000 == 0 {
295                    format!("{}s", ms / 1000)
296                } else {
297                    format!("{}ms", ms)
298                }
299            }
300            VmValue::EnumVariant {
301                enum_name,
302                variant,
303                fields,
304            } => {
305                if fields.is_empty() {
306                    format!("{enum_name}.{variant}")
307                } else {
308                    let inner: Vec<String> = fields.iter().map(|v| v.display()).collect();
309                    format!("{enum_name}.{variant}({})", inner.join(", "))
310                }
311            }
312            VmValue::StructInstance {
313                struct_name,
314                fields,
315            } => {
316                let inner: Vec<String> = fields
317                    .iter()
318                    .map(|(k, v)| format!("{k}: {}", v.display()))
319                    .collect();
320                format!("{struct_name} {{{}}}", inner.join(", "))
321            }
322            VmValue::TaskHandle(id) => format!("<task:{id}>"),
323            VmValue::Channel(ch) => format!("<channel:{}>", ch.name),
324            VmValue::Atomic(a) => format!("<atomic:{}>", a.value.load(Ordering::SeqCst)),
325            VmValue::McpClient(c) => format!("<mcp_client:{}>", c.name),
326            VmValue::Set(items) => {
327                let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
328                format!("set({})", inner.join(", "))
329            }
330        }
331    }
332
333    /// Get the value as a BTreeMap reference, if it's a Dict.
334    pub fn as_dict(&self) -> Option<&BTreeMap<String, VmValue>> {
335        if let VmValue::Dict(d) = self {
336            Some(d)
337        } else {
338            None
339        }
340    }
341
342    pub fn as_int(&self) -> Option<i64> {
343        if let VmValue::Int(n) = self {
344            Some(*n)
345        } else {
346            None
347        }
348    }
349}
350
/// Sync builtin function for the VM.
///
/// Receives the call's arguments as a slice plus a mutable `String`, and
/// returns the result value or a `VmError`.
/// NOTE(review): the `&mut String` is presumably an output buffer the builtin
/// may append to (e.g. printed text); confirm against callers.
pub type VmBuiltinFn = Rc<dyn Fn(&[VmValue], &mut String) -> Result<VmValue, VmError>>;
353
354pub fn values_equal(a: &VmValue, b: &VmValue) -> bool {
355    match (a, b) {
356        (VmValue::Int(x), VmValue::Int(y)) => x == y,
357        (VmValue::Float(x), VmValue::Float(y)) => x == y,
358        (VmValue::String(x), VmValue::String(y)) => x == y,
359        (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
360        (VmValue::Nil, VmValue::Nil) => true,
361        (VmValue::Int(x), VmValue::Float(y)) => (*x as f64) == *y,
362        (VmValue::Float(x), VmValue::Int(y)) => *x == (*y as f64),
363        (VmValue::TaskHandle(a), VmValue::TaskHandle(b)) => a == b,
364        (VmValue::Channel(_), VmValue::Channel(_)) => false, // channels are never equal
365        (VmValue::Atomic(a), VmValue::Atomic(b)) => {
366            a.value.load(Ordering::SeqCst) == b.value.load(Ordering::SeqCst)
367        }
368        (VmValue::List(a), VmValue::List(b)) => {
369            a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y))
370        }
371        (VmValue::Dict(a), VmValue::Dict(b)) => {
372            a.len() == b.len()
373                && a.iter()
374                    .zip(b.iter())
375                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
376        }
377        (
378            VmValue::EnumVariant {
379                enum_name: a_e,
380                variant: a_v,
381                fields: a_f,
382            },
383            VmValue::EnumVariant {
384                enum_name: b_e,
385                variant: b_v,
386                fields: b_f,
387            },
388        ) => {
389            a_e == b_e
390                && a_v == b_v
391                && a_f.len() == b_f.len()
392                && a_f.iter().zip(b_f.iter()).all(|(x, y)| values_equal(x, y))
393        }
394        (
395            VmValue::StructInstance {
396                struct_name: a_s,
397                fields: a_f,
398            },
399            VmValue::StructInstance {
400                struct_name: b_s,
401                fields: b_f,
402            },
403        ) => {
404            a_s == b_s
405                && a_f.len() == b_f.len()
406                && a_f
407                    .iter()
408                    .zip(b_f.iter())
409                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
410        }
411        (VmValue::Set(a), VmValue::Set(b)) => {
412            a.len() == b.len() && a.iter().all(|x| b.iter().any(|y| values_equal(x, y)))
413        }
414        _ => false,
415    }
416}
417
418pub fn compare_values(a: &VmValue, b: &VmValue) -> i32 {
419    match (a, b) {
420        (VmValue::Int(x), VmValue::Int(y)) => x.cmp(y) as i32,
421        (VmValue::Float(x), VmValue::Float(y)) => {
422            if x < y {
423                -1
424            } else if x > y {
425                1
426            } else {
427                0
428            }
429        }
430        (VmValue::Int(x), VmValue::Float(y)) => {
431            let x = *x as f64;
432            if x < *y {
433                -1
434            } else if x > *y {
435                1
436            } else {
437                0
438            }
439        }
440        (VmValue::Float(x), VmValue::Int(y)) => {
441            let y = *y as f64;
442            if *x < y {
443                -1
444            } else if *x > y {
445                1
446            } else {
447                0
448            }
449        }
450        (VmValue::String(x), VmValue::String(y)) => x.cmp(y) as i32,
451        _ => 0,
452    }
453}