Skip to main content

harn_vm/
value.rs

1use std::collections::BTreeMap;
2use std::rc::Rc;
3use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
4use std::sync::Arc;
5
6use crate::chunk::CompiledFunction;
7use crate::mcp::VmMcpClientHandle;
8
/// An async builtin function for the VM.
///
/// The callable takes ownership of its argument list and returns a boxed,
/// pinned future that resolves to the builtin's result. The alias is an `Rc`
/// and the future carries no `Send` bound.
pub type VmAsyncBuiltinFn = Rc<
    dyn Fn(
        Vec<VmValue>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<VmValue, VmError>>>>,
>;
15
/// The raw join handle type for spawned tasks.
///
/// A finished task yields either a `(VmValue, String)` pair or a `VmError`.
/// NOTE(review): the `String` is presumably the task's captured output — confirm.
pub type VmJoinHandle = tokio::task::JoinHandle<Result<(VmValue, String), VmError>>;
18
/// A spawned async task handle with cancellation support.
pub struct VmTaskHandle {
    /// Tokio join handle used to await the task's result.
    pub handle: VmJoinHandle,
    /// Cooperative cancellation token. Set to true to request graceful shutdown.
    pub cancel_token: Arc<AtomicBool>,
}
25
/// A channel handle for the VM (uses tokio mpsc).
///
/// Cloning the handle clones the `Arc`s, so all clones refer to the same
/// underlying sender, receiver and closed flag.
#[derive(Debug, Clone)]
pub struct VmChannelHandle {
    /// Channel name; shown when the value is displayed (`<channel:NAME>`).
    pub name: String,
    /// Sending half of the tokio mpsc channel.
    pub sender: Arc<tokio::sync::mpsc::Sender<VmValue>>,
    /// Receiving half, guarded by an async mutex.
    pub receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
    /// Shared closed flag.
    /// NOTE(review): presumably set once the channel is closed — confirm against the channel builtins.
    pub closed: Arc<AtomicBool>,
}
34
/// An atomic integer handle for the VM.
///
/// Clones share the same underlying `AtomicI64`.
#[derive(Debug, Clone)]
pub struct VmAtomicHandle {
    /// The shared 64-bit atomic counter.
    pub value: Arc<AtomicI64>,
}
40
/// A generator object: lazily produces values via yield.
/// The generator body runs as a spawned task that sends values through a channel.
///
/// Clones share the same `done` flag and receiver (both are `Rc`).
#[derive(Debug, Clone)]
pub struct VmGenerator {
    /// Whether the generator has finished (returned or exhausted).
    pub done: Rc<std::cell::Cell<bool>>,
    /// Receiver end of the yield channel (generator sends values here).
    /// Wrapped in a shared async mutex so recv() can be called without holding
    /// a RefCell borrow across await points.
    pub receiver: Rc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
}
52
/// VM runtime value.
///
/// Strings, lists, dicts, sets and closures are reference-counted, so cloning
/// those variants only bumps a refcount.
#[derive(Debug, Clone)]
pub enum VmValue {
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit float.
    Float(f64),
    /// Immutable shared string.
    String(Rc<str>),
    /// Boolean.
    Bool(bool),
    /// The absent value; displayed as `nil`.
    Nil,
    /// Ordered list of values.
    List(Rc<Vec<VmValue>>),
    /// String-keyed dictionary, iterated in key order.
    Dict(Rc<BTreeMap<String, VmValue>>),
    /// Compiled function together with its environment.
    Closure(Rc<VmClosure>),
    /// Duration in milliseconds.
    Duration(u64),
    /// An enum variant with positional payload fields.
    EnumVariant {
        enum_name: String,
        variant: String,
        fields: Vec<VmValue>,
    },
    /// A struct instance with named fields.
    StructInstance {
        struct_name: String,
        fields: BTreeMap<String, VmValue>,
    },
    /// Identifier of a spawned task; displayed as `<task:ID>`.
    TaskHandle(String),
    /// Handle to a channel.
    Channel(VmChannelHandle),
    /// Handle to a shared atomic integer.
    Atomic(VmAtomicHandle),
    /// Handle to an MCP client.
    McpClient(VmMcpClientHandle),
    /// Set of values, stored as a vector.
    Set(Rc<Vec<VmValue>>),
    /// Lazy generator.
    Generator(VmGenerator),
}
81
/// A compiled closure value.
#[derive(Debug, Clone)]
pub struct VmClosure {
    /// The compiled function body and its parameter list.
    pub func: CompiledFunction,
    /// Environment stored with the closure.
    /// NOTE(review): presumably the variables captured at creation time — confirm.
    pub env: VmEnv,
    /// Source directory for this closure's originating module.
    /// When set, `render()` and other source-relative builtins resolve
    /// paths relative to this directory instead of the entry pipeline.
    pub source_dir: Option<std::path::PathBuf>,
}
92
/// VM environment for variable storage.
///
/// Variables live in a stack of scopes. Lookups and assignments search from
/// the last (innermost) scope towards the first; definitions always go into
/// the last scope.
#[derive(Debug, Clone)]
pub struct VmEnv {
    /// Scope stack; the last element is the innermost scope.
    pub(crate) scopes: Vec<Scope>,
}
98
/// A single lexical scope: variable name mapped to its value and mutability flag.
#[derive(Debug, Clone)]
pub(crate) struct Scope {
    pub(crate) vars: BTreeMap<String, (VmValue, bool)>, // (value, mutable)
}
103
104impl Default for VmEnv {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110impl VmEnv {
111    pub fn new() -> Self {
112        Self {
113            scopes: vec![Scope {
114                vars: BTreeMap::new(),
115            }],
116        }
117    }
118
119    pub fn push_scope(&mut self) {
120        self.scopes.push(Scope {
121            vars: BTreeMap::new(),
122        });
123    }
124
125    pub fn get(&self, name: &str) -> Option<VmValue> {
126        for scope in self.scopes.iter().rev() {
127            if let Some((val, _)) = scope.vars.get(name) {
128                return Some(val.clone());
129            }
130        }
131        None
132    }
133
134    pub fn define(&mut self, name: &str, value: VmValue, mutable: bool) -> Result<(), VmError> {
135        if let Some(scope) = self.scopes.last_mut() {
136            if let Some((_, existing_mutable)) = scope.vars.get(name) {
137                if !existing_mutable && !mutable {
138                    return Err(VmError::Runtime(format!(
139                        "Cannot redeclare immutable variable '{name}' in the same scope (use 'var' for mutable bindings)"
140                    )));
141                }
142            }
143            scope.vars.insert(name.to_string(), (value, mutable));
144        }
145        Ok(())
146    }
147
148    pub fn all_variables(&self) -> BTreeMap<String, VmValue> {
149        let mut vars = BTreeMap::new();
150        for scope in &self.scopes {
151            for (name, (value, _)) in &scope.vars {
152                vars.insert(name.clone(), value.clone());
153            }
154        }
155        vars
156    }
157
158    pub fn assign(&mut self, name: &str, value: VmValue) -> Result<(), VmError> {
159        for scope in self.scopes.iter_mut().rev() {
160            if let Some((_, mutable)) = scope.vars.get(name) {
161                if !mutable {
162                    return Err(VmError::ImmutableAssignment(name.to_string()));
163                }
164                scope.vars.insert(name.to_string(), (value, true));
165                return Ok(());
166            }
167        }
168        Err(VmError::UndefinedVariable(name.to_string()))
169    }
170}
171
/// Compute Levenshtein edit distance between two strings.
///
/// The distance is measured in `char`s, not bytes. Only two rows of the
/// dynamic-programming table are kept at any time.
fn levenshtein(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();
    let width = target.len();
    // `above` is the previous table row, `row` the one being filled in.
    let mut above: Vec<usize> = (0..=width).collect();
    let mut row = vec![0usize; width + 1];
    for (i, ch_a) in a.chars().enumerate() {
        row[0] = i + 1;
        for (j, ch_b) in target.iter().enumerate() {
            let substitute = above[j] + usize::from(ch_a != *ch_b);
            let delete = above[j + 1] + 1;
            let insert = row[j] + 1;
            row[j + 1] = delete.min(insert).min(substitute);
        }
        std::mem::swap(&mut above, &mut row);
    }
    above[width]
}
190
191/// Find the closest match from a list of candidates using Levenshtein distance.
192/// Returns `Some(suggestion)` if a candidate is within `max_dist` edits.
193pub fn closest_match<'a>(name: &str, candidates: impl Iterator<Item = &'a str>) -> Option<String> {
194    let max_dist = match name.len() {
195        0..=2 => 1,
196        3..=5 => 2,
197        _ => 3,
198    };
199    candidates
200        .filter(|c| *c != name && !c.starts_with("__"))
201        .map(|c| (c, levenshtein(name, c)))
202        .filter(|(_, d)| *d <= max_dist)
203        // Prefer smallest distance, then closest length to original, then alphabetical
204        .min_by(|(a, da), (b, db)| {
205            da.cmp(db)
206                .then_with(|| {
207                    let a_diff = (a.len() as isize - name.len() as isize).unsigned_abs();
208                    let b_diff = (b.len() as isize - name.len() as isize).unsigned_abs();
209                    a_diff.cmp(&b_diff)
210                })
211                .then_with(|| a.cmp(b))
212        })
213        .map(|(c, _)| c.to_string())
214}
215
/// VM runtime errors.
#[derive(Debug, Clone)]
pub enum VmError {
    /// Reported as "Stack underflow".
    StackUnderflow,
    /// Reported as "Stack overflow: too many nested calls".
    StackOverflow,
    /// A variable name that no scope defines.
    UndefinedVariable(String),
    /// A builtin name that is not defined.
    UndefinedBuiltin(String),
    /// Assignment to an immutable binding; carries the binding name.
    ImmutableAssignment(String),
    /// Type error with a descriptive message.
    TypeError(String),
    /// General runtime error with a descriptive message.
    Runtime(String),
    /// Reported as "Division by zero".
    DivisionByZero,
    /// A thrown value.
    Thrown(VmValue),
    /// Thrown with error category for structured error handling.
    CategorizedError {
        message: String,
        category: ErrorCategory,
    },
    /// Carries a function's return value.
    /// NOTE(review): presumably a control-flow signal rather than a real failure — confirm.
    Return(VmValue),
    /// An opcode byte reported as invalid.
    InvalidInstruction(u8),
}
235
/// Error categories for structured error handling in agent orchestration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Network/connection timeout
    Timeout,
    /// Authentication/authorization failure
    Auth,
    /// Rate limit exceeded
    RateLimit,
    /// Tool execution failure
    ToolError,
    /// Tool was rejected by the host (not permitted / not in allowlist)
    ToolRejected,
    /// Operation was cancelled
    Cancelled,
    /// Resource not found
    NotFound,
    /// Circuit breaker is open
    CircuitOpen,
    /// Generic/unclassified error
    Generic,
}

impl ErrorCategory {
    /// Returns the stable snake_case name of this category.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Timeout => "timeout",
            Self::Auth => "auth",
            Self::RateLimit => "rate_limit",
            Self::ToolError => "tool_error",
            Self::ToolRejected => "tool_rejected",
            Self::Cancelled => "cancelled",
            Self::NotFound => "not_found",
            Self::CircuitOpen => "circuit_open",
            Self::Generic => "generic",
        }
    }

    /// Parses a category name as produced by [`ErrorCategory::as_str`].
    ///
    /// Matching is exact and case-sensitive. Any unrecognised string,
    /// including `"generic"` itself, yields [`ErrorCategory::Generic`].
    pub fn parse(s: &str) -> Self {
        // Every category except the `Generic` fallback.
        const NAMED: [ErrorCategory; 8] = [
            ErrorCategory::Timeout,
            ErrorCategory::Auth,
            ErrorCategory::RateLimit,
            ErrorCategory::ToolError,
            ErrorCategory::ToolRejected,
            ErrorCategory::Cancelled,
            ErrorCategory::NotFound,
            ErrorCategory::CircuitOpen,
        ];
        NAMED
            .iter()
            .find(|candidate| candidate.as_str() == s)
            .cloned()
            .unwrap_or(ErrorCategory::Generic)
    }
}
288
289/// Create a categorized error conveniently.
290pub fn categorized_error(message: impl Into<String>, category: ErrorCategory) -> VmError {
291    VmError::CategorizedError {
292        message: message.into(),
293        category,
294    }
295}
296
297/// Extract error category from a VmError.
298///
299/// Classification priority:
300/// 1. Explicit CategorizedError variant (set by throw_error or internal code)
301/// 2. Thrown dict with a "category" field (user-created structured errors)
302/// 3. HTTP status code extraction (standard, unambiguous)
303/// 4. Deadline exceeded (VM-internal)
304/// 5. Fallback to Generic
305pub fn error_to_category(err: &VmError) -> ErrorCategory {
306    match err {
307        VmError::CategorizedError { category, .. } => category.clone(),
308        VmError::Thrown(VmValue::Dict(d)) => d
309            .get("category")
310            .map(|v| ErrorCategory::parse(&v.display()))
311            .unwrap_or(ErrorCategory::Generic),
312        VmError::Thrown(VmValue::String(s)) => classify_error_message(s),
313        VmError::Runtime(msg) => classify_error_message(msg),
314        _ => ErrorCategory::Generic,
315    }
316}
317
318/// Classify an error message using HTTP status codes and well-known patterns.
319/// Prefers unambiguous signals (status codes) over substring heuristics.
320fn classify_error_message(msg: &str) -> ErrorCategory {
321    // 1. HTTP status codes — most reliable signal
322    if let Some(cat) = classify_by_http_status(msg) {
323        return cat;
324    }
325    // 2. Well-known error identifiers from major APIs
326    //    (Anthropic, OpenAI, and standard HTTP patterns)
327    if msg.contains("Deadline exceeded") || msg.contains("context deadline exceeded") {
328        return ErrorCategory::Timeout;
329    }
330    if msg.contains("overloaded_error") || msg.contains("api_error") {
331        // Anthropic-specific error types
332        return ErrorCategory::RateLimit;
333    }
334    if msg.contains("insufficient_quota") || msg.contains("billing_hard_limit_reached") {
335        // OpenAI-specific error types
336        return ErrorCategory::RateLimit;
337    }
338    if msg.contains("invalid_api_key") || msg.contains("authentication_error") {
339        return ErrorCategory::Auth;
340    }
341    if msg.contains("not_found_error") || msg.contains("model_not_found") {
342        return ErrorCategory::NotFound;
343    }
344    if msg.contains("circuit_open") {
345        return ErrorCategory::CircuitOpen;
346    }
347    ErrorCategory::Generic
348}
349
350/// Classify errors by HTTP status code if one appears in the message.
351/// This is the most reliable classification method since status codes
352/// are standardized (RFC 9110) and unambiguous.
353fn classify_by_http_status(msg: &str) -> Option<ErrorCategory> {
354    // Extract 3-digit HTTP status codes from common patterns:
355    // "HTTP 429", "status 429", "429 Too Many", "error: 401"
356    for code in extract_http_status_codes(msg) {
357        return Some(match code {
358            401 | 403 => ErrorCategory::Auth,
359            404 | 410 => ErrorCategory::NotFound,
360            408 | 504 | 522 | 524 => ErrorCategory::Timeout,
361            429 | 503 => ErrorCategory::RateLimit,
362            _ => continue,
363        });
364    }
365    None
366}
367
/// Extract plausible HTTP status codes from an error message.
///
/// A candidate is a run of exactly three ASCII digits, with no digit directly
/// before or after it. Only values in the 400-599 range are kept, in order of
/// appearance.
fn extract_http_status_codes(msg: &str) -> Vec<u16> {
    let bytes = msg.as_bytes();
    bytes
        .windows(3)
        .enumerate()
        .filter(|(_, digits)| digits.iter().all(u8::is_ascii_digit))
        .filter(|(start, _)| {
            // Reject triples that are part of a longer number.
            let digit_before = *start > 0 && bytes[*start - 1].is_ascii_digit();
            let digit_after = bytes
                .get(*start + 3)
                .map_or(false, u8::is_ascii_digit);
            !digit_before && !digit_after
        })
        .map(|(_, digits)| {
            digits
                .iter()
                .fold(0u16, |acc, digit| acc * 10 + u16::from(*digit - b'0'))
        })
        .filter(|code| (400..=599).contains(code))
        .collect()
}
392
393impl std::fmt::Display for VmError {
394    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395        match self {
396            VmError::StackUnderflow => write!(f, "Stack underflow"),
397            VmError::StackOverflow => write!(f, "Stack overflow: too many nested calls"),
398            VmError::UndefinedVariable(n) => write!(f, "Undefined variable: {n}"),
399            VmError::UndefinedBuiltin(n) => write!(f, "Undefined builtin: {n}"),
400            VmError::ImmutableAssignment(n) => {
401                write!(f, "Cannot assign to immutable binding: {n}")
402            }
403            VmError::TypeError(msg) => write!(f, "Type error: {msg}"),
404            VmError::Runtime(msg) => write!(f, "Runtime error: {msg}"),
405            VmError::DivisionByZero => write!(f, "Division by zero"),
406            VmError::Thrown(v) => write!(f, "Thrown: {}", v.display()),
407            VmError::CategorizedError { message, category } => {
408                write!(f, "Error [{}]: {}", category.as_str(), message)
409            }
410            VmError::Return(_) => write!(f, "Return from function"),
411            VmError::InvalidInstruction(op) => write!(f, "Invalid instruction: 0x{op:02x}"),
412        }
413    }
414}
415
// Empty impl: the trait's provided methods are used as-is, on top of the
// `Debug` derive and the `Display` impl for `VmError`.
impl std::error::Error for VmError {}
417
418impl VmValue {
419    pub fn is_truthy(&self) -> bool {
420        match self {
421            VmValue::Bool(b) => *b,
422            VmValue::Nil => false,
423            VmValue::Int(n) => *n != 0,
424            VmValue::Float(n) => *n != 0.0,
425            VmValue::String(s) => !s.is_empty(),
426            VmValue::List(l) => !l.is_empty(),
427            VmValue::Dict(d) => !d.is_empty(),
428            VmValue::Closure(_) => true,
429            VmValue::Duration(ms) => *ms > 0,
430            VmValue::EnumVariant { .. } => true,
431            VmValue::StructInstance { .. } => true,
432            VmValue::TaskHandle(_) => true,
433            VmValue::Channel(_) => true,
434            VmValue::Atomic(_) => true,
435            VmValue::McpClient(_) => true,
436            VmValue::Set(s) => !s.is_empty(),
437            VmValue::Generator(_) => true,
438        }
439    }
440
441    pub fn type_name(&self) -> &'static str {
442        match self {
443            VmValue::String(_) => "string",
444            VmValue::Int(_) => "int",
445            VmValue::Float(_) => "float",
446            VmValue::Bool(_) => "bool",
447            VmValue::Nil => "nil",
448            VmValue::List(_) => "list",
449            VmValue::Dict(_) => "dict",
450            VmValue::Closure(_) => "closure",
451            VmValue::Duration(_) => "duration",
452            VmValue::EnumVariant { .. } => "enum",
453            VmValue::StructInstance { .. } => "struct",
454            VmValue::TaskHandle(_) => "task_handle",
455            VmValue::Channel(_) => "channel",
456            VmValue::Atomic(_) => "atomic",
457            VmValue::McpClient(_) => "mcp_client",
458            VmValue::Set(_) => "set",
459            VmValue::Generator(_) => "generator",
460        }
461    }
462
463    pub fn display(&self) -> String {
464        match self {
465            VmValue::Int(n) => n.to_string(),
466            VmValue::Float(n) => {
467                if *n == (*n as i64) as f64 && n.abs() < 1e15 {
468                    format!("{:.1}", n)
469                } else {
470                    n.to_string()
471                }
472            }
473            VmValue::String(s) => s.to_string(),
474            VmValue::Bool(b) => (if *b { "true" } else { "false" }).to_string(),
475            VmValue::Nil => "nil".to_string(),
476            VmValue::List(items) => {
477                let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
478                format!("[{}]", inner.join(", "))
479            }
480            VmValue::Dict(map) => {
481                let inner: Vec<String> = map
482                    .iter()
483                    .map(|(k, v)| format!("{k}: {}", v.display()))
484                    .collect();
485                format!("{{{}}}", inner.join(", "))
486            }
487            VmValue::Closure(c) => format!("<fn({})>", c.func.params.join(", ")),
488            VmValue::Duration(ms) => {
489                if *ms >= 3_600_000 && ms % 3_600_000 == 0 {
490                    format!("{}h", ms / 3_600_000)
491                } else if *ms >= 60_000 && ms % 60_000 == 0 {
492                    format!("{}m", ms / 60_000)
493                } else if *ms >= 1000 && ms % 1000 == 0 {
494                    format!("{}s", ms / 1000)
495                } else {
496                    format!("{}ms", ms)
497                }
498            }
499            VmValue::EnumVariant {
500                enum_name,
501                variant,
502                fields,
503            } => {
504                if fields.is_empty() {
505                    format!("{enum_name}.{variant}")
506                } else {
507                    let inner: Vec<String> = fields.iter().map(|v| v.display()).collect();
508                    format!("{enum_name}.{variant}({})", inner.join(", "))
509                }
510            }
511            VmValue::StructInstance {
512                struct_name,
513                fields,
514            } => {
515                let inner: Vec<String> = fields
516                    .iter()
517                    .map(|(k, v)| format!("{k}: {}", v.display()))
518                    .collect();
519                format!("{struct_name} {{{}}}", inner.join(", "))
520            }
521            VmValue::TaskHandle(id) => format!("<task:{id}>"),
522            VmValue::Channel(ch) => format!("<channel:{}>", ch.name),
523            VmValue::Atomic(a) => format!("<atomic:{}>", a.value.load(Ordering::SeqCst)),
524            VmValue::McpClient(c) => format!("<mcp_client:{}>", c.name),
525            VmValue::Set(items) => {
526                let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
527                format!("set({})", inner.join(", "))
528            }
529            VmValue::Generator(g) => {
530                if g.done.get() {
531                    "<generator (done)>".to_string()
532                } else {
533                    "<generator>".to_string()
534                }
535            }
536        }
537    }
538
539    /// Get the value as a BTreeMap reference, if it's a Dict.
540    pub fn as_dict(&self) -> Option<&BTreeMap<String, VmValue>> {
541        if let VmValue::Dict(d) = self {
542            Some(d)
543        } else {
544            None
545        }
546    }
547
548    pub fn as_int(&self) -> Option<i64> {
549        if let VmValue::Int(n) = self {
550            Some(*n)
551        } else {
552            None
553        }
554    }
555}
556
/// Sync builtin function for the VM.
///
/// Receives the argument slice plus a mutable `String`.
/// NOTE(review): the `String` is presumably the VM's output buffer — confirm.
pub type VmBuiltinFn = Rc<dyn Fn(&[VmValue], &mut String) -> Result<VmValue, VmError>>;
559
560pub fn values_equal(a: &VmValue, b: &VmValue) -> bool {
561    match (a, b) {
562        (VmValue::Int(x), VmValue::Int(y)) => x == y,
563        (VmValue::Float(x), VmValue::Float(y)) => x == y,
564        (VmValue::String(x), VmValue::String(y)) => x == y,
565        (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
566        (VmValue::Nil, VmValue::Nil) => true,
567        (VmValue::Int(x), VmValue::Float(y)) => (*x as f64) == *y,
568        (VmValue::Float(x), VmValue::Int(y)) => *x == (*y as f64),
569        (VmValue::TaskHandle(a), VmValue::TaskHandle(b)) => a == b,
570        (VmValue::Channel(_), VmValue::Channel(_)) => false, // channels are never equal
571        (VmValue::Atomic(a), VmValue::Atomic(b)) => {
572            a.value.load(Ordering::SeqCst) == b.value.load(Ordering::SeqCst)
573        }
574        (VmValue::List(a), VmValue::List(b)) => {
575            a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y))
576        }
577        (VmValue::Dict(a), VmValue::Dict(b)) => {
578            a.len() == b.len()
579                && a.iter()
580                    .zip(b.iter())
581                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
582        }
583        (
584            VmValue::EnumVariant {
585                enum_name: a_e,
586                variant: a_v,
587                fields: a_f,
588            },
589            VmValue::EnumVariant {
590                enum_name: b_e,
591                variant: b_v,
592                fields: b_f,
593            },
594        ) => {
595            a_e == b_e
596                && a_v == b_v
597                && a_f.len() == b_f.len()
598                && a_f.iter().zip(b_f.iter()).all(|(x, y)| values_equal(x, y))
599        }
600        (
601            VmValue::StructInstance {
602                struct_name: a_s,
603                fields: a_f,
604            },
605            VmValue::StructInstance {
606                struct_name: b_s,
607                fields: b_f,
608            },
609        ) => {
610            a_s == b_s
611                && a_f.len() == b_f.len()
612                && a_f
613                    .iter()
614                    .zip(b_f.iter())
615                    .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
616        }
617        (VmValue::Set(a), VmValue::Set(b)) => {
618            a.len() == b.len() && a.iter().all(|x| b.iter().any(|y| values_equal(x, y)))
619        }
620        (VmValue::Generator(_), VmValue::Generator(_)) => false, // generators are never equal
621        _ => false,
622    }
623}
624
625pub fn compare_values(a: &VmValue, b: &VmValue) -> i32 {
626    match (a, b) {
627        (VmValue::Int(x), VmValue::Int(y)) => x.cmp(y) as i32,
628        (VmValue::Float(x), VmValue::Float(y)) => {
629            if x < y {
630                -1
631            } else if x > y {
632                1
633            } else {
634                0
635            }
636        }
637        (VmValue::Int(x), VmValue::Float(y)) => {
638            let x = *x as f64;
639            if x < *y {
640                -1
641            } else if x > *y {
642                1
643            } else {
644                0
645            }
646        }
647        (VmValue::Float(x), VmValue::Int(y)) => {
648            let y = *y as f64;
649            if *x < y {
650                -1
651            } else if *x > y {
652                1
653            } else {
654                0
655            }
656        }
657        (VmValue::String(x), VmValue::String(y)) => x.cmp(y) as i32,
658        _ => 0,
659    }
660}