Skip to main content

harn_vm/
value.rs

1use std::collections::BTreeMap;
2use std::rc::Rc;
3use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
4use std::sync::Arc;
5
6use crate::chunk::CompiledFunction;
7use crate::mcp::VmMcpClientHandle;
8
9/// An async builtin function for the VM.
10pub type VmAsyncBuiltinFn = Rc<
11    dyn Fn(
12        Vec<VmValue>,
13    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<VmValue, VmError>>>>,
14>;
15
16/// A spawned async task handle.
17pub type VmTaskHandle = tokio::task::JoinHandle<Result<(VmValue, String), VmError>>;
18
19/// A channel handle for the VM (uses tokio mpsc).
20#[derive(Debug, Clone)]
21pub struct VmChannelHandle {
22    pub name: String,
23    pub sender: Arc<tokio::sync::mpsc::Sender<VmValue>>,
24    pub receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
25    pub closed: Arc<AtomicBool>,
26}
27
/// Shared 64-bit atomic integer exposed to VM programs.
///
/// The counter lives behind an `Arc`, so cloning the handle produces another
/// reference to the same counter rather than a copy of its current value.
#[derive(Clone, Debug)]
pub struct VmAtomicHandle {
    /// The shared counter.
    pub value: Arc<AtomicI64>,
}
33
34/// VM runtime value.
35#[derive(Debug, Clone)]
36pub enum VmValue {
37    Int(i64),
38    Float(f64),
39    String(Rc<str>),
40    Bool(bool),
41    Nil,
42    List(Rc<Vec<VmValue>>),
43    Dict(Rc<BTreeMap<String, VmValue>>),
44    Closure(Rc<VmClosure>),
45    Duration(u64),
46    EnumVariant {
47        enum_name: String,
48        variant: String,
49        fields: Vec<VmValue>,
50    },
51    StructInstance {
52        struct_name: String,
53        fields: BTreeMap<String, VmValue>,
54    },
55    TaskHandle(String),
56    Channel(VmChannelHandle),
57    Atomic(VmAtomicHandle),
58    McpClient(VmMcpClientHandle),
59}
60
61/// A compiled closure value.
62#[derive(Debug, Clone)]
63pub struct VmClosure {
64    pub func: CompiledFunction,
65    pub env: VmEnv,
66}
67
68/// VM environment for variable storage.
69#[derive(Debug, Clone)]
70pub struct VmEnv {
71    pub(crate) scopes: Vec<Scope>,
72}
73
74#[derive(Debug, Clone)]
75pub(crate) struct Scope {
76    pub(crate) vars: BTreeMap<String, (VmValue, bool)>, // (value, mutable)
77}
78
79impl Default for VmEnv {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl VmEnv {
86    pub fn new() -> Self {
87        Self {
88            scopes: vec![Scope {
89                vars: BTreeMap::new(),
90            }],
91        }
92    }
93
94    pub fn push_scope(&mut self) {
95        self.scopes.push(Scope {
96            vars: BTreeMap::new(),
97        });
98    }
99
100    pub fn get(&self, name: &str) -> Option<VmValue> {
101        for scope in self.scopes.iter().rev() {
102            if let Some((val, _)) = scope.vars.get(name) {
103                return Some(val.clone());
104            }
105        }
106        None
107    }
108
109    pub fn define(&mut self, name: &str, value: VmValue, mutable: bool) {
110        if let Some(scope) = self.scopes.last_mut() {
111            scope.vars.insert(name.to_string(), (value, mutable));
112        }
113    }
114
115    pub fn all_variables(&self) -> BTreeMap<String, VmValue> {
116        let mut vars = BTreeMap::new();
117        for scope in &self.scopes {
118            for (name, (value, _)) in &scope.vars {
119                vars.insert(name.clone(), value.clone());
120            }
121        }
122        vars
123    }
124
125    pub fn assign(&mut self, name: &str, value: VmValue) -> Result<(), VmError> {
126        for scope in self.scopes.iter_mut().rev() {
127            if let Some((_, mutable)) = scope.vars.get(name) {
128                if !mutable {
129                    return Err(VmError::ImmutableAssignment(name.to_string()));
130                }
131                scope.vars.insert(name.to_string(), (value, true));
132                return Ok(());
133            }
134        }
135        Err(VmError::UndefinedVariable(name.to_string()))
136    }
137}
138
139/// VM runtime errors.
140#[derive(Debug, Clone)]
141pub enum VmError {
142    StackUnderflow,
143    StackOverflow,
144    UndefinedVariable(String),
145    UndefinedBuiltin(String),
146    ImmutableAssignment(String),
147    TypeError(String),
148    Runtime(String),
149    DivisionByZero,
150    Thrown(VmValue),
151    Return(VmValue),
152    InvalidInstruction(u8),
153}
154
155impl std::fmt::Display for VmError {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        match self {
158            VmError::StackUnderflow => write!(f, "Stack underflow"),
159            VmError::StackOverflow => write!(f, "Stack overflow: too many nested calls"),
160            VmError::UndefinedVariable(n) => write!(f, "Undefined variable: {n}"),
161            VmError::UndefinedBuiltin(n) => write!(f, "Undefined builtin: {n}"),
162            VmError::ImmutableAssignment(n) => {
163                write!(f, "Cannot assign to immutable binding: {n}")
164            }
165            VmError::TypeError(msg) => write!(f, "Type error: {msg}"),
166            VmError::Runtime(msg) => write!(f, "Runtime error: {msg}"),
167            VmError::DivisionByZero => write!(f, "Division by zero"),
168            VmError::Thrown(v) => write!(f, "Thrown: {}", v.display()),
169            VmError::Return(_) => write!(f, "Return from function"),
170            VmError::InvalidInstruction(op) => write!(f, "Invalid instruction: 0x{op:02x}"),
171        }
172    }
173}
174
175impl std::error::Error for VmError {}
176
177impl VmValue {
178    pub fn is_truthy(&self) -> bool {
179        match self {
180            VmValue::Bool(b) => *b,
181            VmValue::Nil => false,
182            VmValue::Int(n) => *n != 0,
183            VmValue::Float(n) => *n != 0.0,
184            VmValue::String(s) => !s.is_empty(),
185            VmValue::List(l) => !l.is_empty(),
186            VmValue::Dict(d) => !d.is_empty(),
187            VmValue::Closure(_) => true,
188            VmValue::Duration(ms) => *ms > 0,
189            VmValue::EnumVariant { .. } => true,
190            VmValue::StructInstance { .. } => true,
191            VmValue::TaskHandle(_) => true,
192            VmValue::Channel(_) => true,
193            VmValue::Atomic(_) => true,
194            VmValue::McpClient(_) => true,
195        }
196    }
197
198    pub fn type_name(&self) -> &'static str {
199        match self {
200            VmValue::String(_) => "string",
201            VmValue::Int(_) => "int",
202            VmValue::Float(_) => "float",
203            VmValue::Bool(_) => "bool",
204            VmValue::Nil => "nil",
205            VmValue::List(_) => "list",
206            VmValue::Dict(_) => "dict",
207            VmValue::Closure(_) => "closure",
208            VmValue::Duration(_) => "duration",
209            VmValue::EnumVariant { .. } => "enum",
210            VmValue::StructInstance { .. } => "struct",
211            VmValue::TaskHandle(_) => "task_handle",
212            VmValue::Channel(_) => "channel",
213            VmValue::Atomic(_) => "atomic",
214            VmValue::McpClient(_) => "mcp_client",
215        }
216    }
217
218    pub fn display(&self) -> String {
219        match self {
220            VmValue::Int(n) => n.to_string(),
221            VmValue::Float(n) => {
222                if *n == (*n as i64) as f64 && n.abs() < 1e15 {
223                    format!("{:.1}", n)
224                } else {
225                    n.to_string()
226                }
227            }
228            VmValue::String(s) => s.to_string(),
229            VmValue::Bool(b) => (if *b { "true" } else { "false" }).to_string(),
230            VmValue::Nil => "nil".to_string(),
231            VmValue::List(items) => {
232                let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
233                format!("[{}]", inner.join(", "))
234            }
235            VmValue::Dict(map) => {
236                let inner: Vec<String> = map
237                    .iter()
238                    .map(|(k, v)| format!("{k}: {}", v.display()))
239                    .collect();
240                format!("{{{}}}", inner.join(", "))
241            }
242            VmValue::Closure(c) => format!("<fn({})>", c.func.params.join(", ")),
243            VmValue::Duration(ms) => {
244                if *ms >= 3_600_000 && ms % 3_600_000 == 0 {
245                    format!("{}h", ms / 3_600_000)
246                } else if *ms >= 60_000 && ms % 60_000 == 0 {
247                    format!("{}m", ms / 60_000)
248                } else if *ms >= 1000 && ms % 1000 == 0 {
249                    format!("{}s", ms / 1000)
250                } else {
251                    format!("{}ms", ms)
252                }
253            }
254            VmValue::EnumVariant {
255                enum_name,
256                variant,
257                fields,
258            } => {
259                if fields.is_empty() {
260                    format!("{enum_name}.{variant}")
261                } else {
262                    let inner: Vec<String> = fields.iter().map(|v| v.display()).collect();
263                    format!("{enum_name}.{variant}({})", inner.join(", "))
264                }
265            }
266            VmValue::StructInstance {
267                struct_name,
268                fields,
269            } => {
270                let inner: Vec<String> = fields
271                    .iter()
272                    .map(|(k, v)| format!("{k}: {}", v.display()))
273                    .collect();
274                format!("{struct_name} {{{}}}", inner.join(", "))
275            }
276            VmValue::TaskHandle(id) => format!("<task:{id}>"),
277            VmValue::Channel(ch) => format!("<channel:{}>", ch.name),
278            VmValue::Atomic(a) => format!("<atomic:{}>", a.value.load(Ordering::SeqCst)),
279            VmValue::McpClient(c) => format!("<mcp_client:{}>", c.name),
280        }
281    }
282
283    /// Get the value as a BTreeMap reference, if it's a Dict.
284    pub fn as_dict(&self) -> Option<&BTreeMap<String, VmValue>> {
285        if let VmValue::Dict(d) = self {
286            Some(d)
287        } else {
288            None
289        }
290    }
291
292    pub fn as_int(&self) -> Option<i64> {
293        if let VmValue::Int(n) = self {
294            Some(*n)
295        } else {
296            None
297        }
298    }
299}
300
301/// Sync builtin function for the VM.
302pub type VmBuiltinFn = Rc<dyn Fn(&[VmValue], &mut String) -> Result<VmValue, VmError>>;
303
304pub fn values_equal(a: &VmValue, b: &VmValue) -> bool {
305    match (a, b) {
306        (VmValue::Int(x), VmValue::Int(y)) => x == y,
307        (VmValue::Float(x), VmValue::Float(y)) => x == y,
308        (VmValue::String(x), VmValue::String(y)) => x == y,
309        (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
310        (VmValue::Nil, VmValue::Nil) => true,
311        (VmValue::Int(x), VmValue::Float(y)) => (*x as f64) == *y,
312        (VmValue::Float(x), VmValue::Int(y)) => *x == (*y as f64),
313        (VmValue::TaskHandle(a), VmValue::TaskHandle(b)) => a == b,
314        (VmValue::Channel(_), VmValue::Channel(_)) => false, // channels are never equal
315        (VmValue::Atomic(a), VmValue::Atomic(b)) => {
316            a.value.load(Ordering::SeqCst) == b.value.load(Ordering::SeqCst)
317        }
318        (VmValue::List(a), VmValue::List(b)) => {
319            a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y))
320        }
321        (VmValue::Dict(a), VmValue::Dict(b)) => {
322            a.len() == b.len()
323                && a.iter()
324                    .zip(b.iter())
325                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
326        }
327        (
328            VmValue::EnumVariant {
329                enum_name: a_e,
330                variant: a_v,
331                fields: a_f,
332            },
333            VmValue::EnumVariant {
334                enum_name: b_e,
335                variant: b_v,
336                fields: b_f,
337            },
338        ) => {
339            a_e == b_e
340                && a_v == b_v
341                && a_f.len() == b_f.len()
342                && a_f.iter().zip(b_f.iter()).all(|(x, y)| values_equal(x, y))
343        }
344        (
345            VmValue::StructInstance {
346                struct_name: a_s,
347                fields: a_f,
348            },
349            VmValue::StructInstance {
350                struct_name: b_s,
351                fields: b_f,
352            },
353        ) => {
354            a_s == b_s
355                && a_f.len() == b_f.len()
356                && a_f
357                    .iter()
358                    .zip(b_f.iter())
359                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
360        }
361        _ => false,
362    }
363}
364
365pub fn compare_values(a: &VmValue, b: &VmValue) -> i32 {
366    match (a, b) {
367        (VmValue::Int(x), VmValue::Int(y)) => x.cmp(y) as i32,
368        (VmValue::Float(x), VmValue::Float(y)) => {
369            if x < y {
370                -1
371            } else if x > y {
372                1
373            } else {
374                0
375            }
376        }
377        (VmValue::Int(x), VmValue::Float(y)) => {
378            let x = *x as f64;
379            if x < *y {
380                -1
381            } else if x > *y {
382                1
383            } else {
384                0
385            }
386        }
387        (VmValue::Float(x), VmValue::Int(y)) => {
388            let y = *y as f64;
389            if *x < y {
390                -1
391            } else if *x > y {
392                1
393            } else {
394                0
395            }
396        }
397        (VmValue::String(x), VmValue::String(y)) => x.cmp(y) as i32,
398        _ => 0,
399    }
400}