// harn_vm/value.rs

use std::collections::BTreeMap;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::Arc;

use crate::chunk::CompiledFunction;
use crate::mcp::VmMcpClientHandle;

9/// An async builtin function for the VM.
10pub type VmAsyncBuiltinFn = Rc<
11    dyn Fn(
12        Vec<VmValue>,
13    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<VmValue, VmError>>>>,
14>;
15
16/// The raw join handle type for spawned tasks.
17pub type VmJoinHandle = tokio::task::JoinHandle<Result<(VmValue, String), VmError>>;
18
19/// A spawned async task handle with cancellation support.
20pub struct VmTaskHandle {
21    pub handle: VmJoinHandle,
22    /// Cooperative cancellation token. Set to true to request graceful shutdown.
23    pub cancel_token: Arc<AtomicBool>,
24}
25
26/// A channel handle for the VM (uses tokio mpsc).
27#[derive(Debug, Clone)]
28pub struct VmChannelHandle {
29    pub name: String,
30    pub sender: Arc<tokio::sync::mpsc::Sender<VmValue>>,
31    pub receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
32    pub closed: Arc<AtomicBool>,
33}
34
/// An atomic integer handle for the VM.
///
/// Clones share the same underlying `AtomicI64`.
#[derive(Debug, Clone)]
pub struct VmAtomicHandle {
    /// The shared atomic counter.
    pub value: Arc<AtomicI64>,
}

41/// A generator object: lazily produces values via yield.
42/// The generator body runs as a spawned task that sends values through a channel.
43#[derive(Debug, Clone)]
44pub struct VmGenerator {
45    /// Whether the generator has finished (returned or exhausted).
46    pub done: Rc<std::cell::Cell<bool>>,
47    /// Receiver end of the yield channel (generator sends values here).
48    /// Wrapped in a shared async mutex so recv() can be called without holding
49    /// a RefCell borrow across await points.
50    pub receiver: Rc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
51}
52
53/// VM runtime value.
54#[derive(Debug, Clone)]
55pub enum VmValue {
56    Int(i64),
57    Float(f64),
58    String(Rc<str>),
59    Bool(bool),
60    Nil,
61    List(Rc<Vec<VmValue>>),
62    Dict(Rc<BTreeMap<String, VmValue>>),
63    Closure(Rc<VmClosure>),
64    Duration(u64),
65    EnumVariant {
66        enum_name: String,
67        variant: String,
68        fields: Vec<VmValue>,
69    },
70    StructInstance {
71        struct_name: String,
72        fields: BTreeMap<String, VmValue>,
73    },
74    TaskHandle(String),
75    Channel(VmChannelHandle),
76    Atomic(VmAtomicHandle),
77    McpClient(VmMcpClientHandle),
78    Set(Rc<Vec<VmValue>>),
79    Generator(VmGenerator),
80}
81
82/// A compiled closure value.
83#[derive(Debug, Clone)]
84pub struct VmClosure {
85    pub func: CompiledFunction,
86    pub env: VmEnv,
87}
88
89/// VM environment for variable storage.
90#[derive(Debug, Clone)]
91pub struct VmEnv {
92    pub(crate) scopes: Vec<Scope>,
93}
94
95#[derive(Debug, Clone)]
96pub(crate) struct Scope {
97    pub(crate) vars: BTreeMap<String, (VmValue, bool)>, // (value, mutable)
98}
99
100impl Default for VmEnv {
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
106impl VmEnv {
107    pub fn new() -> Self {
108        Self {
109            scopes: vec![Scope {
110                vars: BTreeMap::new(),
111            }],
112        }
113    }
114
115    pub fn push_scope(&mut self) {
116        self.scopes.push(Scope {
117            vars: BTreeMap::new(),
118        });
119    }
120
121    pub fn get(&self, name: &str) -> Option<VmValue> {
122        for scope in self.scopes.iter().rev() {
123            if let Some((val, _)) = scope.vars.get(name) {
124                return Some(val.clone());
125            }
126        }
127        None
128    }
129
130    pub fn define(&mut self, name: &str, value: VmValue, mutable: bool) {
131        if let Some(scope) = self.scopes.last_mut() {
132            scope.vars.insert(name.to_string(), (value, mutable));
133        }
134    }
135
136    pub fn all_variables(&self) -> BTreeMap<String, VmValue> {
137        let mut vars = BTreeMap::new();
138        for scope in &self.scopes {
139            for (name, (value, _)) in &scope.vars {
140                vars.insert(name.clone(), value.clone());
141            }
142        }
143        vars
144    }
145
146    pub fn assign(&mut self, name: &str, value: VmValue) -> Result<(), VmError> {
147        for scope in self.scopes.iter_mut().rev() {
148            if let Some((_, mutable)) = scope.vars.get(name) {
149                if !mutable {
150                    return Err(VmError::ImmutableAssignment(name.to_string()));
151                }
152                scope.vars.insert(name.to_string(), (value, true));
153                return Ok(());
154            }
155        }
156        Err(VmError::UndefinedVariable(name.to_string()))
157    }
158}
159
/// Compute Levenshtein edit distance between two strings.
///
/// The distance is measured in Unicode scalar values (`char`s), not bytes.
fn levenshtein(a: &str, b: &str) -> usize {
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();
    // Single rolling row: row[j] is the distance between the prefix of `a`
    // processed so far and `b[..j]`.
    let mut row: Vec<usize> = (0..=b.len()).collect();
    for (i, &ca) in a.iter().enumerate() {
        // `diag` carries the previous row's value at column j.
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            let substitute = diag + usize::from(ca != cb);
            diag = row[j + 1];
            row[j + 1] = substitute.min(diag + 1).min(row[j] + 1);
        }
    }
    row[b.len()]
}

/// Find the closest match from a list of candidates using Levenshtein distance.
///
/// Returns `Some(suggestion)` if a candidate is within the edit budget, which
/// scales with the length of `name` in characters: 1 edit for up to 2
/// characters, 2 edits for 3 to 5, and 3 edits otherwise. Candidates equal to
/// `name` or starting with `__` are never suggested.
///
/// Ties are broken by smallest distance, then closest length to `name`, then
/// alphabetical order. All lengths are counted in characters so they agree
/// with the unit the edit distance uses.
pub fn closest_match<'a>(name: &str, candidates: impl Iterator<Item = &'a str>) -> Option<String> {
    let name_len = name.chars().count();
    let max_dist = match name_len {
        0..=2 => 1,
        3..=5 => 2,
        _ => 3,
    };
    candidates
        .filter(|c| *c != name && !c.starts_with("__"))
        .map(|c| (c, levenshtein(name, c)))
        .filter(|(_, d)| *d <= max_dist)
        .min_by(|(a, da), (b, db)| {
            da.cmp(db)
                .then_with(|| {
                    let a_diff = a.chars().count().abs_diff(name_len);
                    let b_diff = b.chars().count().abs_diff(name_len);
                    a_diff.cmp(&b_diff)
                })
                .then_with(|| a.cmp(b))
        })
        .map(|(c, _)| c.to_string())
}

204#[derive(Debug, Clone)]
205pub enum VmError {
206    StackUnderflow,
207    StackOverflow,
208    UndefinedVariable(String),
209    UndefinedBuiltin(String),
210    ImmutableAssignment(String),
211    TypeError(String),
212    Runtime(String),
213    DivisionByZero,
214    Thrown(VmValue),
215    /// Thrown with error category for structured error handling.
216    CategorizedError {
217        message: String,
218        category: ErrorCategory,
219    },
220    Return(VmValue),
221    InvalidInstruction(u8),
222}
223
/// Error categories for structured error handling in agent orchestration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Network/connection timeout
    Timeout,
    /// Authentication/authorization failure
    Auth,
    /// Rate limit exceeded
    RateLimit,
    /// Tool execution failure
    ToolError,
    /// Operation was cancelled
    Cancelled,
    /// Resource not found
    NotFound,
    /// Circuit breaker is open
    CircuitOpen,
    /// Generic/unclassified error
    Generic,
}

impl ErrorCategory {
    /// Returns the stable snake_case identifier for this category.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Timeout => "timeout",
            Self::Auth => "auth",
            Self::RateLimit => "rate_limit",
            Self::ToolError => "tool_error",
            Self::Cancelled => "cancelled",
            Self::NotFound => "not_found",
            Self::CircuitOpen => "circuit_open",
            Self::Generic => "generic",
        }
    }

    /// Parses an identifier produced by [`ErrorCategory::as_str`].
    ///
    /// Matching is exact and case-sensitive; any unrecognised input maps to
    /// [`ErrorCategory::Generic`], so this never fails.
    pub fn parse(s: &str) -> Self {
        match s {
            "timeout" => Self::Timeout,
            "auth" => Self::Auth,
            "rate_limit" => Self::RateLimit,
            "tool_error" => Self::ToolError,
            "cancelled" => Self::Cancelled,
            "not_found" => Self::NotFound,
            "circuit_open" => Self::CircuitOpen,
            _ => Self::Generic,
        }
    }
}

273/// Create a categorized error conveniently.
274pub fn categorized_error(message: impl Into<String>, category: ErrorCategory) -> VmError {
275    VmError::CategorizedError {
276        message: message.into(),
277        category,
278    }
279}
280
281/// Extract error category from a VmError.
282///
283/// Classification priority:
284/// 1. Explicit CategorizedError variant (set by throw_error or internal code)
285/// 2. Thrown dict with a "category" field (user-created structured errors)
286/// 3. HTTP status code extraction (standard, unambiguous)
287/// 4. Deadline exceeded (VM-internal)
288/// 5. Fallback to Generic
289pub fn error_to_category(err: &VmError) -> ErrorCategory {
290    match err {
291        VmError::CategorizedError { category, .. } => category.clone(),
292        VmError::Thrown(VmValue::Dict(d)) => d
293            .get("category")
294            .map(|v| ErrorCategory::parse(&v.display()))
295            .unwrap_or(ErrorCategory::Generic),
296        VmError::Thrown(VmValue::String(s)) => classify_error_message(s),
297        VmError::Runtime(msg) => classify_error_message(msg),
298        _ => ErrorCategory::Generic,
299    }
300}
301
302/// Classify an error message using HTTP status codes and well-known patterns.
303/// Prefers unambiguous signals (status codes) over substring heuristics.
304fn classify_error_message(msg: &str) -> ErrorCategory {
305    // 1. HTTP status codes — most reliable signal
306    if let Some(cat) = classify_by_http_status(msg) {
307        return cat;
308    }
309    // 2. Well-known error identifiers from major APIs
310    //    (Anthropic, OpenAI, and standard HTTP patterns)
311    if msg.contains("Deadline exceeded") || msg.contains("context deadline exceeded") {
312        return ErrorCategory::Timeout;
313    }
314    if msg.contains("overloaded_error") || msg.contains("api_error") {
315        // Anthropic-specific error types
316        return ErrorCategory::RateLimit;
317    }
318    if msg.contains("insufficient_quota") || msg.contains("billing_hard_limit_reached") {
319        // OpenAI-specific error types
320        return ErrorCategory::RateLimit;
321    }
322    if msg.contains("invalid_api_key") || msg.contains("authentication_error") {
323        return ErrorCategory::Auth;
324    }
325    if msg.contains("not_found_error") || msg.contains("model_not_found") {
326        return ErrorCategory::NotFound;
327    }
328    if msg.contains("circuit_open") {
329        return ErrorCategory::CircuitOpen;
330    }
331    ErrorCategory::Generic
332}
333
334/// Classify errors by HTTP status code if one appears in the message.
335/// This is the most reliable classification method since status codes
336/// are standardized (RFC 9110) and unambiguous.
337fn classify_by_http_status(msg: &str) -> Option<ErrorCategory> {
338    // Extract 3-digit HTTP status codes from common patterns:
339    // "HTTP 429", "status 429", "429 Too Many", "error: 401"
340    for code in extract_http_status_codes(msg) {
341        return Some(match code {
342            401 | 403 => ErrorCategory::Auth,
343            404 | 410 => ErrorCategory::NotFound,
344            408 | 504 | 522 | 524 => ErrorCategory::Timeout,
345            429 | 503 => ErrorCategory::RateLimit,
346            _ => continue,
347        });
348    }
349    None
350}
351
/// Extract plausible HTTP status codes from an error message.
///
/// A candidate is a run of exactly three ASCII digits (not part of a longer
/// number). Only client/server error codes, 400 through 599, are returned,
/// in order of appearance.
fn extract_http_status_codes(msg: &str) -> Vec<u16> {
    let bytes = msg.as_bytes();
    bytes
        .windows(3)
        .enumerate()
        .filter(|(_, window)| window.iter().all(u8::is_ascii_digit))
        .filter(|&(start, _)| {
            // Ensure the three digits are not part of a longer number.
            let digit_before = start > 0 && bytes[start - 1].is_ascii_digit();
            let digit_after = bytes
                .get(start + 3)
                .map_or(false, u8::is_ascii_digit);
            !digit_before && !digit_after
        })
        .map(|(_, window)| {
            window
                .iter()
                .fold(0u16, |acc, digit| acc * 10 + u16::from(digit - b'0'))
        })
        .filter(|code| (400..=599).contains(code))
        .collect()
}

377impl std::fmt::Display for VmError {
378    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379        match self {
380            VmError::StackUnderflow => write!(f, "Stack underflow"),
381            VmError::StackOverflow => write!(f, "Stack overflow: too many nested calls"),
382            VmError::UndefinedVariable(n) => write!(f, "Undefined variable: {n}"),
383            VmError::UndefinedBuiltin(n) => write!(f, "Undefined builtin: {n}"),
384            VmError::ImmutableAssignment(n) => {
385                write!(f, "Cannot assign to immutable binding: {n}")
386            }
387            VmError::TypeError(msg) => write!(f, "Type error: {msg}"),
388            VmError::Runtime(msg) => write!(f, "Runtime error: {msg}"),
389            VmError::DivisionByZero => write!(f, "Division by zero"),
390            VmError::Thrown(v) => write!(f, "Thrown: {}", v.display()),
391            VmError::CategorizedError { message, category } => {
392                write!(f, "Error [{}]: {}", category.as_str(), message)
393            }
394            VmError::Return(_) => write!(f, "Return from function"),
395            VmError::InvalidInstruction(op) => write!(f, "Invalid instruction: 0x{op:02x}"),
396        }
397    }
398}
399
400impl std::error::Error for VmError {}
401
402impl VmValue {
403    pub fn is_truthy(&self) -> bool {
404        match self {
405            VmValue::Bool(b) => *b,
406            VmValue::Nil => false,
407            VmValue::Int(n) => *n != 0,
408            VmValue::Float(n) => *n != 0.0,
409            VmValue::String(s) => !s.is_empty(),
410            VmValue::List(l) => !l.is_empty(),
411            VmValue::Dict(d) => !d.is_empty(),
412            VmValue::Closure(_) => true,
413            VmValue::Duration(ms) => *ms > 0,
414            VmValue::EnumVariant { .. } => true,
415            VmValue::StructInstance { .. } => true,
416            VmValue::TaskHandle(_) => true,
417            VmValue::Channel(_) => true,
418            VmValue::Atomic(_) => true,
419            VmValue::McpClient(_) => true,
420            VmValue::Set(s) => !s.is_empty(),
421            VmValue::Generator(_) => true,
422        }
423    }
424
425    pub fn type_name(&self) -> &'static str {
426        match self {
427            VmValue::String(_) => "string",
428            VmValue::Int(_) => "int",
429            VmValue::Float(_) => "float",
430            VmValue::Bool(_) => "bool",
431            VmValue::Nil => "nil",
432            VmValue::List(_) => "list",
433            VmValue::Dict(_) => "dict",
434            VmValue::Closure(_) => "closure",
435            VmValue::Duration(_) => "duration",
436            VmValue::EnumVariant { .. } => "enum",
437            VmValue::StructInstance { .. } => "struct",
438            VmValue::TaskHandle(_) => "task_handle",
439            VmValue::Channel(_) => "channel",
440            VmValue::Atomic(_) => "atomic",
441            VmValue::McpClient(_) => "mcp_client",
442            VmValue::Set(_) => "set",
443            VmValue::Generator(_) => "generator",
444        }
445    }
446
447    pub fn display(&self) -> String {
448        match self {
449            VmValue::Int(n) => n.to_string(),
450            VmValue::Float(n) => {
451                if *n == (*n as i64) as f64 && n.abs() < 1e15 {
452                    format!("{:.1}", n)
453                } else {
454                    n.to_string()
455                }
456            }
457            VmValue::String(s) => s.to_string(),
458            VmValue::Bool(b) => (if *b { "true" } else { "false" }).to_string(),
459            VmValue::Nil => "nil".to_string(),
460            VmValue::List(items) => {
461                let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
462                format!("[{}]", inner.join(", "))
463            }
464            VmValue::Dict(map) => {
465                let inner: Vec<String> = map
466                    .iter()
467                    .map(|(k, v)| format!("{k}: {}", v.display()))
468                    .collect();
469                format!("{{{}}}", inner.join(", "))
470            }
471            VmValue::Closure(c) => format!("<fn({})>", c.func.params.join(", ")),
472            VmValue::Duration(ms) => {
473                if *ms >= 3_600_000 && ms % 3_600_000 == 0 {
474                    format!("{}h", ms / 3_600_000)
475                } else if *ms >= 60_000 && ms % 60_000 == 0 {
476                    format!("{}m", ms / 60_000)
477                } else if *ms >= 1000 && ms % 1000 == 0 {
478                    format!("{}s", ms / 1000)
479                } else {
480                    format!("{}ms", ms)
481                }
482            }
483            VmValue::EnumVariant {
484                enum_name,
485                variant,
486                fields,
487            } => {
488                if fields.is_empty() {
489                    format!("{enum_name}.{variant}")
490                } else {
491                    let inner: Vec<String> = fields.iter().map(|v| v.display()).collect();
492                    format!("{enum_name}.{variant}({})", inner.join(", "))
493                }
494            }
495            VmValue::StructInstance {
496                struct_name,
497                fields,
498            } => {
499                let inner: Vec<String> = fields
500                    .iter()
501                    .map(|(k, v)| format!("{k}: {}", v.display()))
502                    .collect();
503                format!("{struct_name} {{{}}}", inner.join(", "))
504            }
505            VmValue::TaskHandle(id) => format!("<task:{id}>"),
506            VmValue::Channel(ch) => format!("<channel:{}>", ch.name),
507            VmValue::Atomic(a) => format!("<atomic:{}>", a.value.load(Ordering::SeqCst)),
508            VmValue::McpClient(c) => format!("<mcp_client:{}>", c.name),
509            VmValue::Set(items) => {
510                let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
511                format!("set({})", inner.join(", "))
512            }
513            VmValue::Generator(g) => {
514                if g.done.get() {
515                    "<generator (done)>".to_string()
516                } else {
517                    "<generator>".to_string()
518                }
519            }
520        }
521    }
522
523    /// Get the value as a BTreeMap reference, if it's a Dict.
524    pub fn as_dict(&self) -> Option<&BTreeMap<String, VmValue>> {
525        if let VmValue::Dict(d) = self {
526            Some(d)
527        } else {
528            None
529        }
530    }
531
532    pub fn as_int(&self) -> Option<i64> {
533        if let VmValue::Int(n) = self {
534            Some(*n)
535        } else {
536            None
537        }
538    }
539}
540
541/// Sync builtin function for the VM.
542pub type VmBuiltinFn = Rc<dyn Fn(&[VmValue], &mut String) -> Result<VmValue, VmError>>;
543
544pub fn values_equal(a: &VmValue, b: &VmValue) -> bool {
545    match (a, b) {
546        (VmValue::Int(x), VmValue::Int(y)) => x == y,
547        (VmValue::Float(x), VmValue::Float(y)) => x == y,
548        (VmValue::String(x), VmValue::String(y)) => x == y,
549        (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
550        (VmValue::Nil, VmValue::Nil) => true,
551        (VmValue::Int(x), VmValue::Float(y)) => (*x as f64) == *y,
552        (VmValue::Float(x), VmValue::Int(y)) => *x == (*y as f64),
553        (VmValue::TaskHandle(a), VmValue::TaskHandle(b)) => a == b,
554        (VmValue::Channel(_), VmValue::Channel(_)) => false, // channels are never equal
555        (VmValue::Atomic(a), VmValue::Atomic(b)) => {
556            a.value.load(Ordering::SeqCst) == b.value.load(Ordering::SeqCst)
557        }
558        (VmValue::List(a), VmValue::List(b)) => {
559            a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y))
560        }
561        (VmValue::Dict(a), VmValue::Dict(b)) => {
562            a.len() == b.len()
563                && a.iter()
564                    .zip(b.iter())
565                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
566        }
567        (
568            VmValue::EnumVariant {
569                enum_name: a_e,
570                variant: a_v,
571                fields: a_f,
572            },
573            VmValue::EnumVariant {
574                enum_name: b_e,
575                variant: b_v,
576                fields: b_f,
577            },
578        ) => {
579            a_e == b_e
580                && a_v == b_v
581                && a_f.len() == b_f.len()
582                && a_f.iter().zip(b_f.iter()).all(|(x, y)| values_equal(x, y))
583        }
584        (
585            VmValue::StructInstance {
586                struct_name: a_s,
587                fields: a_f,
588            },
589            VmValue::StructInstance {
590                struct_name: b_s,
591                fields: b_f,
592            },
593        ) => {
594            a_s == b_s
595                && a_f.len() == b_f.len()
596                && a_f
597                    .iter()
598                    .zip(b_f.iter())
599                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
600        }
601        (VmValue::Set(a), VmValue::Set(b)) => {
602            a.len() == b.len() && a.iter().all(|x| b.iter().any(|y| values_equal(x, y)))
603        }
604        (VmValue::Generator(_), VmValue::Generator(_)) => false, // generators are never equal
605        _ => false,
606    }
607}
608
609pub fn compare_values(a: &VmValue, b: &VmValue) -> i32 {
610    match (a, b) {
611        (VmValue::Int(x), VmValue::Int(y)) => x.cmp(y) as i32,
612        (VmValue::Float(x), VmValue::Float(y)) => {
613            if x < y {
614                -1
615            } else if x > y {
616                1
617            } else {
618                0
619            }
620        }
621        (VmValue::Int(x), VmValue::Float(y)) => {
622            let x = *x as f64;
623            if x < *y {
624                -1
625            } else if x > *y {
626                1
627            } else {
628                0
629            }
630        }
631        (VmValue::Float(x), VmValue::Int(y)) => {
632            let y = *y as f64;
633            if *x < y {
634                -1
635            } else if *x > y {
636                1
637            } else {
638                0
639            }
640        }
641        (VmValue::String(x), VmValue::String(y)) => x.cmp(y) as i32,
642        _ => 0,
643    }
644}