1use std::collections::BTreeMap;
2use std::rc::Rc;
3use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
4use std::sync::Arc;
5use std::{cell::RefCell, path::PathBuf};
6
7use crate::chunk::CompiledFunction;
8use crate::mcp::VmMcpClientHandle;
9
/// Async native function callable from scripts.
///
/// It receives the call arguments by value and returns a pinned, boxed future
/// that resolves to the result value or a `VmError`. The future carries no
/// `Send` bound, and the function itself is held in an `Rc`, so builtins of
/// this kind are single-threaded.
pub type VmAsyncBuiltinFn = Rc<
    dyn Fn(
        Vec<VmValue>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<VmValue, VmError>>>>,
>;

/// Tokio join handle for a spawned VM task.
///
/// On success the task yields a value paired with a `String`. NOTE(review):
/// the `String` is presumably output captured while the task ran — confirm
/// against the spawning code.
pub type VmJoinHandle = tokio::task::JoinHandle<Result<(VmValue, String), VmError>>;
19
/// A spawned task together with a shared cancellation flag.
pub struct VmTaskHandle {
    /// Join handle used to await the task's result.
    pub handle: VmJoinHandle,
    /// Flag shared (via `Arc`) with the task. NOTE(review): presumably set to
    /// request cooperative cancellation — confirm the task polls it.
    pub cancel_token: Arc<AtomicBool>,
}
26
/// A named channel value backed by a bounded tokio `mpsc` channel.
///
/// Both ends live behind `Arc`, so clones of this handle refer to the same
/// channel. The receiver sits behind an async mutex because an `mpsc`
/// receiver has a single consumer and receiving requires `&mut` access.
#[derive(Debug, Clone)]
pub struct VmChannelHandle {
    /// Name shown by `VmValue::display` as `<channel:NAME>`.
    pub name: String,
    /// Sending half, shared by all clones.
    pub sender: Arc<tokio::sync::mpsc::Sender<VmValue>>,
    /// Receiving half, shared by all clones and serialised by the mutex.
    pub receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
    /// Shared closed flag. NOTE(review): presumably set when the script
    /// closes the channel — confirm.
    pub closed: Arc<AtomicBool>,
}
35
/// Shared 64-bit integer cell.
///
/// Clones point at the same `AtomicI64`. `VmValue::display` and
/// `values_equal` read it with `SeqCst` ordering.
#[derive(Debug, Clone)]
pub struct VmAtomicHandle {
    pub value: Arc<AtomicI64>,
}
41
/// Handle to a generator whose values are received over a tokio `mpsc`
/// channel.
///
/// Clones share the `done` flag and the receiver (both behind `Rc`), so this
/// type is single-threaded.
#[derive(Debug, Clone)]
pub struct VmGenerator {
    /// Whether the generator is exhausted; `VmValue::display` renders
    /// `<generator (done)>` when set. NOTE(review): presumably set once the
    /// channel closes — confirm where it is written.
    pub done: Rc<std::cell::Cell<bool>>,
    /// Receiving side for produced values, behind an async mutex so clones
    /// can take turns receiving.
    pub receiver: Rc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
}
53
/// A runtime value manipulated by the VM.
///
/// Cloning is shallow for the `Rc`-backed variants (`String`, `List`, `Dict`,
/// `Closure`, `Set`): clones share the same allocation. The payloads of
/// `EnumVariant` and `StructInstance` are owned and are copied on clone.
#[derive(Debug, Clone)]
pub enum VmValue {
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit float.
    Float(f64),
    /// Immutable, reference-counted string.
    String(Rc<str>),
    Bool(bool),
    /// The absence of a value; falsy.
    Nil,
    /// Ordered list, shared via `Rc`.
    List(Rc<Vec<VmValue>>),
    /// String-keyed map. `BTreeMap` keeps keys sorted, so iteration and
    /// display order are deterministic.
    Dict(Rc<BTreeMap<String, VmValue>>),
    /// A function value together with its captured state.
    Closure(Rc<VmClosure>),
    /// A duration; `display` treats this number as milliseconds.
    Duration(u64),
    /// A variant of a named enum with positional payload fields.
    EnumVariant {
        enum_name: String,
        variant: String,
        fields: Vec<VmValue>,
    },
    /// An instance of a named struct; fields keyed by name.
    StructInstance {
        struct_name: String,
        fields: BTreeMap<String, VmValue>,
    },
    /// Identifier of a spawned task. NOTE(review): presumably a key into the
    /// VM's task table — confirm.
    TaskHandle(String),
    Channel(VmChannelHandle),
    Atomic(VmAtomicHandle),
    McpClient(VmMcpClientHandle),
    /// Set of values stored as a `Vec`; `values_equal` compares sets without
    /// regard to element order.
    Set(Rc<Vec<VmValue>>),
    Generator(VmGenerator),
}
82
/// A function value: compiled code plus the environment it closes over.
#[derive(Debug, Clone)]
pub struct VmClosure {
    /// Compiled function body; `VmValue::display` shows its parameter names.
    pub func: CompiledFunction,
    /// Variable environment stored with the function. Because `VmEnv` is
    /// cloned by value, this is a snapshot of the scopes' maps.
    pub env: VmEnv,
    /// NOTE(review): presumably the directory of the defining source file,
    /// used to resolve relative paths — confirm with callers.
    pub source_dir: Option<PathBuf>,
    /// Optional shared registry of named module functions.
    /// NOTE(review): if a closure stored in this registry also holds the same
    /// registry here, the resulting `Rc` cycle is never freed — confirm this
    /// is acceptable or broken elsewhere.
    pub module_functions: Option<ModuleFunctionRegistry>,
}

/// Shared, interior-mutable map from function name to closure.
/// Built from `Rc<RefCell<..>>`, so it is single-threaded.
pub type ModuleFunctionRegistry = Rc<RefCell<BTreeMap<String, Rc<VmClosure>>>>;
99
/// Lexically scoped variable environment.
///
/// `scopes` is a stack whose last element is the innermost scope; lookups
/// walk it from the end. Cloning copies every scope's map (values themselves
/// are cloned shallowly where they are `Rc`-backed).
#[derive(Debug, Clone)]
pub struct VmEnv {
    pub(crate) scopes: Vec<Scope>,
}

/// One lexical scope.
#[derive(Debug, Clone)]
pub(crate) struct Scope {
    /// Variable name -> (value, is_mutable).
    pub(crate) vars: BTreeMap<String, (VmValue, bool)>,
}
110
111impl Default for VmEnv {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117impl VmEnv {
118 pub fn new() -> Self {
119 Self {
120 scopes: vec![Scope {
121 vars: BTreeMap::new(),
122 }],
123 }
124 }
125
126 pub fn push_scope(&mut self) {
127 self.scopes.push(Scope {
128 vars: BTreeMap::new(),
129 });
130 }
131
132 pub fn pop_scope(&mut self) {
133 if self.scopes.len() > 1 {
134 self.scopes.pop();
135 }
136 }
137
138 pub fn scope_depth(&self) -> usize {
139 self.scopes.len()
140 }
141
142 pub fn truncate_scopes(&mut self, target_depth: usize) {
143 let min_depth = target_depth.max(1);
144 while self.scopes.len() > min_depth {
145 self.scopes.pop();
146 }
147 }
148
149 pub fn get(&self, name: &str) -> Option<VmValue> {
150 for scope in self.scopes.iter().rev() {
151 if let Some((val, _)) = scope.vars.get(name) {
152 return Some(val.clone());
153 }
154 }
155 None
156 }
157
158 pub fn define(&mut self, name: &str, value: VmValue, mutable: bool) -> Result<(), VmError> {
159 if let Some(scope) = self.scopes.last_mut() {
160 if let Some((_, existing_mutable)) = scope.vars.get(name) {
161 if !existing_mutable && !mutable {
162 return Err(VmError::Runtime(format!(
163 "Cannot redeclare immutable variable '{name}' in the same scope (use 'var' for mutable bindings)"
164 )));
165 }
166 }
167 scope.vars.insert(name.to_string(), (value, mutable));
168 }
169 Ok(())
170 }
171
172 pub fn all_variables(&self) -> BTreeMap<String, VmValue> {
173 let mut vars = BTreeMap::new();
174 for scope in &self.scopes {
175 for (name, (value, _)) in &scope.vars {
176 vars.insert(name.clone(), value.clone());
177 }
178 }
179 vars
180 }
181
182 pub fn assign(&mut self, name: &str, value: VmValue) -> Result<(), VmError> {
183 for scope in self.scopes.iter_mut().rev() {
184 if let Some((_, mutable)) = scope.vars.get(name) {
185 if !mutable {
186 return Err(VmError::ImmutableAssignment(name.to_string()));
187 }
188 scope.vars.insert(name.to_string(), (value, true));
189 return Ok(());
190 }
191 }
192 Err(VmError::UndefinedVariable(name.to_string()))
193 }
194}
195
/// Levenshtein edit distance between `a` and `b`, counted in `char`s rather
/// than bytes.
///
/// Uses the two-row dynamic-programming formulation, so extra memory is
/// proportional to the length of `b` only.
fn levenshtein(a: &str, b: &str) -> usize {
    let lhs: Vec<char> = a.chars().collect();
    let rhs: Vec<char> = b.chars().collect();
    // `row[j]` is the distance between the processed prefix of `lhs` and the
    // first `j` chars of `rhs`; initially that prefix of `lhs` is empty.
    let mut row: Vec<usize> = (0..=rhs.len()).collect();
    let mut next = vec![0usize; rhs.len() + 1];
    for (i, &lc) in lhs.iter().enumerate() {
        next[0] = i + 1;
        for (j, &rc) in rhs.iter().enumerate() {
            let substitution = row[j] + usize::from(lc != rc);
            let deletion = row[j + 1] + 1;
            let insertion = next[j] + 1;
            next[j + 1] = deletion.min(insertion).min(substitution);
        }
        std::mem::swap(&mut row, &mut next);
    }
    row[rhs.len()]
}
214
215pub fn closest_match<'a>(name: &str, candidates: impl Iterator<Item = &'a str>) -> Option<String> {
218 let max_dist = match name.len() {
219 0..=2 => 1,
220 3..=5 => 2,
221 _ => 3,
222 };
223 candidates
224 .filter(|c| *c != name && !c.starts_with("__"))
225 .map(|c| (c, levenshtein(name, c)))
226 .filter(|(_, d)| *d <= max_dist)
227 .min_by(|(a, da), (b, db)| {
229 da.cmp(db)
230 .then_with(|| {
231 let a_diff = (a.len() as isize - name.len() as isize).unsigned_abs();
232 let b_diff = (b.len() as isize - name.len() as isize).unsigned_abs();
233 a_diff.cmp(&b_diff)
234 })
235 .then_with(|| a.cmp(b))
236 })
237 .map(|(c, _)| c.to_string())
238}
239
/// Errors raised while executing VM code.
///
/// `Return` is also carried through this type. NOTE(review): presumably so a
/// function return can unwind through `?` — confirm.
#[derive(Debug, Clone)]
pub enum VmError {
    /// Displayed as "Stack underflow".
    StackUnderflow,
    /// Call nesting too deep ("Stack overflow: too many nested calls").
    StackOverflow,
    /// Name of a variable that could not be found (e.g. by `VmEnv::assign`).
    UndefinedVariable(String),
    /// Name of a builtin that could not be found.
    UndefinedBuiltin(String),
    /// Name of an immutable binding that was assigned to (see `VmEnv::assign`).
    ImmutableAssignment(String),
    /// Operand type mismatch, with a description.
    TypeError(String),
    /// General runtime failure. `error_to_category` also scans this message
    /// to infer an `ErrorCategory`.
    Runtime(String),
    DivisionByZero,
    /// A value raised as an error. For dict payloads, `error_to_category`
    /// reads a `"category"` key; string payloads are classified by their text.
    Thrown(VmValue),
    /// Error with an explicit category (built by `categorized_error`).
    CategorizedError {
        message: String,
        category: ErrorCategory,
    },
    /// Displayed as "Return from function".
    Return(VmValue),
    /// Unknown opcode byte; displayed in hex.
    InvalidInstruction(u8),
}
259
/// Coarse classification of a failure, identified to scripts by a stable
/// lowercase name (see [`ErrorCategory::as_str`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCategory {
    /// `"timeout"` — e.g. HTTP 408/504/522/524 or a deadline-exceeded message.
    Timeout,
    /// `"auth"` — e.g. HTTP 401/403 or an invalid-key/authentication error.
    Auth,
    /// `"rate_limit"` — e.g. HTTP 429/503, overload, quota or billing limits.
    RateLimit,
    /// `"tool_error"`.
    ToolError,
    /// `"tool_rejected"`.
    ToolRejected,
    /// `"cancelled"`.
    Cancelled,
    /// `"not_found"` — e.g. HTTP 404/410 or a not-found error marker.
    NotFound,
    /// `"circuit_open"`.
    CircuitOpen,
    /// `"generic"` — fallback for anything unrecognised.
    Generic,
}

impl ErrorCategory {
    /// The stable snake_case name of this category.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Timeout => "timeout",
            Self::Auth => "auth",
            Self::RateLimit => "rate_limit",
            Self::ToolError => "tool_error",
            Self::ToolRejected => "tool_rejected",
            Self::Cancelled => "cancelled",
            Self::NotFound => "not_found",
            Self::CircuitOpen => "circuit_open",
            Self::Generic => "generic",
        }
    }

    /// Inverse of [`ErrorCategory::as_str`].
    ///
    /// Matching is exact and case-sensitive. Unrecognised strings fall back
    /// to `Generic`.
    pub fn parse(s: &str) -> Self {
        // Every category with a distinct name; `Generic` is the fallback.
        const NAMED: [ErrorCategory; 8] = [
            ErrorCategory::Timeout,
            ErrorCategory::Auth,
            ErrorCategory::RateLimit,
            ErrorCategory::ToolError,
            ErrorCategory::ToolRejected,
            ErrorCategory::Cancelled,
            ErrorCategory::NotFound,
            ErrorCategory::CircuitOpen,
        ];
        NAMED
            .iter()
            .find(|candidate| candidate.as_str() == s)
            .cloned()
            .unwrap_or(ErrorCategory::Generic)
    }
}
312
313pub fn categorized_error(message: impl Into<String>, category: ErrorCategory) -> VmError {
315 VmError::CategorizedError {
316 message: message.into(),
317 category,
318 }
319}
320
321pub fn error_to_category(err: &VmError) -> ErrorCategory {
330 match err {
331 VmError::CategorizedError { category, .. } => category.clone(),
332 VmError::Thrown(VmValue::Dict(d)) => d
333 .get("category")
334 .map(|v| ErrorCategory::parse(&v.display()))
335 .unwrap_or(ErrorCategory::Generic),
336 VmError::Thrown(VmValue::String(s)) => classify_error_message(s),
337 VmError::Runtime(msg) => classify_error_message(msg),
338 _ => ErrorCategory::Generic,
339 }
340}
341
342fn classify_error_message(msg: &str) -> ErrorCategory {
345 if let Some(cat) = classify_by_http_status(msg) {
347 return cat;
348 }
349 if msg.contains("Deadline exceeded") || msg.contains("context deadline exceeded") {
352 return ErrorCategory::Timeout;
353 }
354 if msg.contains("overloaded_error") || msg.contains("api_error") {
355 return ErrorCategory::RateLimit;
357 }
358 if msg.contains("insufficient_quota") || msg.contains("billing_hard_limit_reached") {
359 return ErrorCategory::RateLimit;
361 }
362 if msg.contains("invalid_api_key") || msg.contains("authentication_error") {
363 return ErrorCategory::Auth;
364 }
365 if msg.contains("not_found_error") || msg.contains("model_not_found") {
366 return ErrorCategory::NotFound;
367 }
368 if msg.contains("circuit_open") {
369 return ErrorCategory::CircuitOpen;
370 }
371 ErrorCategory::Generic
372}
373
374fn classify_by_http_status(msg: &str) -> Option<ErrorCategory> {
378 for code in extract_http_status_codes(msg) {
381 return Some(match code {
382 401 | 403 => ErrorCategory::Auth,
383 404 | 410 => ErrorCategory::NotFound,
384 408 | 504 | 522 | 524 => ErrorCategory::Timeout,
385 429 | 503 => ErrorCategory::RateLimit,
386 _ => continue,
387 });
388 }
389 None
390}
391
/// Extracts every standalone three-digit number in the 400–599 range from
/// `msg`, in order of appearance (duplicates are kept).
///
/// "Standalone" means the three digits are not part of a longer run of ASCII
/// digits: `"1404"` yields nothing, while `"HTTP 404:"` yields 404.
fn extract_http_status_codes(msg: &str) -> Vec<u16> {
    let bytes = msg.as_bytes();
    let digit_at = |idx: usize| bytes.get(idx).map_or(false, u8::is_ascii_digit);
    (0..bytes.len().saturating_sub(2))
        .filter(|&start| {
            // Three digits at start..start+3, with no digit immediately
            // before or after the run.
            (start..start + 3).all(|k| bytes[k].is_ascii_digit())
                && (start == 0 || !bytes[start - 1].is_ascii_digit())
                && !digit_at(start + 3)
        })
        .map(|start| {
            bytes[start..start + 3]
                .iter()
                .fold(0u16, |acc, &digit| acc * 10 + u16::from(digit - b'0'))
        })
        .filter(|code| (400..=599).contains(code))
        .collect()
}
416
417impl std::fmt::Display for VmError {
418 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419 match self {
420 VmError::StackUnderflow => write!(f, "Stack underflow"),
421 VmError::StackOverflow => write!(f, "Stack overflow: too many nested calls"),
422 VmError::UndefinedVariable(n) => write!(f, "Undefined variable: {n}"),
423 VmError::UndefinedBuiltin(n) => write!(f, "Undefined builtin: {n}"),
424 VmError::ImmutableAssignment(n) => {
425 write!(f, "Cannot assign to immutable binding: {n}")
426 }
427 VmError::TypeError(msg) => write!(f, "Type error: {msg}"),
428 VmError::Runtime(msg) => write!(f, "Runtime error: {msg}"),
429 VmError::DivisionByZero => write!(f, "Division by zero"),
430 VmError::Thrown(v) => write!(f, "Thrown: {}", v.display()),
431 VmError::CategorizedError { message, category } => {
432 write!(f, "Error [{}]: {}", category.as_str(), message)
433 }
434 VmError::Return(_) => write!(f, "Return from function"),
435 VmError::InvalidInstruction(op) => write!(f, "Invalid instruction: 0x{op:02x}"),
436 }
437 }
438}
439
440impl std::error::Error for VmError {}
441
442impl VmValue {
443 pub fn is_truthy(&self) -> bool {
444 match self {
445 VmValue::Bool(b) => *b,
446 VmValue::Nil => false,
447 VmValue::Int(n) => *n != 0,
448 VmValue::Float(n) => *n != 0.0,
449 VmValue::String(s) => !s.is_empty(),
450 VmValue::List(l) => !l.is_empty(),
451 VmValue::Dict(d) => !d.is_empty(),
452 VmValue::Closure(_) => true,
453 VmValue::Duration(ms) => *ms > 0,
454 VmValue::EnumVariant { .. } => true,
455 VmValue::StructInstance { .. } => true,
456 VmValue::TaskHandle(_) => true,
457 VmValue::Channel(_) => true,
458 VmValue::Atomic(_) => true,
459 VmValue::McpClient(_) => true,
460 VmValue::Set(s) => !s.is_empty(),
461 VmValue::Generator(_) => true,
462 }
463 }
464
465 pub fn type_name(&self) -> &'static str {
466 match self {
467 VmValue::String(_) => "string",
468 VmValue::Int(_) => "int",
469 VmValue::Float(_) => "float",
470 VmValue::Bool(_) => "bool",
471 VmValue::Nil => "nil",
472 VmValue::List(_) => "list",
473 VmValue::Dict(_) => "dict",
474 VmValue::Closure(_) => "closure",
475 VmValue::Duration(_) => "duration",
476 VmValue::EnumVariant { .. } => "enum",
477 VmValue::StructInstance { .. } => "struct",
478 VmValue::TaskHandle(_) => "task_handle",
479 VmValue::Channel(_) => "channel",
480 VmValue::Atomic(_) => "atomic",
481 VmValue::McpClient(_) => "mcp_client",
482 VmValue::Set(_) => "set",
483 VmValue::Generator(_) => "generator",
484 }
485 }
486
487 pub fn display(&self) -> String {
488 match self {
489 VmValue::Int(n) => n.to_string(),
490 VmValue::Float(n) => {
491 if *n == (*n as i64) as f64 && n.abs() < 1e15 {
492 format!("{:.1}", n)
493 } else {
494 n.to_string()
495 }
496 }
497 VmValue::String(s) => s.to_string(),
498 VmValue::Bool(b) => (if *b { "true" } else { "false" }).to_string(),
499 VmValue::Nil => "nil".to_string(),
500 VmValue::List(items) => {
501 let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
502 format!("[{}]", inner.join(", "))
503 }
504 VmValue::Dict(map) => {
505 let inner: Vec<String> = map
506 .iter()
507 .map(|(k, v)| format!("{k}: {}", v.display()))
508 .collect();
509 format!("{{{}}}", inner.join(", "))
510 }
511 VmValue::Closure(c) => format!("<fn({})>", c.func.params.join(", ")),
512 VmValue::Duration(ms) => {
513 if *ms >= 3_600_000 && ms % 3_600_000 == 0 {
514 format!("{}h", ms / 3_600_000)
515 } else if *ms >= 60_000 && ms % 60_000 == 0 {
516 format!("{}m", ms / 60_000)
517 } else if *ms >= 1000 && ms % 1000 == 0 {
518 format!("{}s", ms / 1000)
519 } else {
520 format!("{}ms", ms)
521 }
522 }
523 VmValue::EnumVariant {
524 enum_name,
525 variant,
526 fields,
527 } => {
528 if fields.is_empty() {
529 format!("{enum_name}.{variant}")
530 } else {
531 let inner: Vec<String> = fields.iter().map(|v| v.display()).collect();
532 format!("{enum_name}.{variant}({})", inner.join(", "))
533 }
534 }
535 VmValue::StructInstance {
536 struct_name,
537 fields,
538 } => {
539 let inner: Vec<String> = fields
540 .iter()
541 .map(|(k, v)| format!("{k}: {}", v.display()))
542 .collect();
543 format!("{struct_name} {{{}}}", inner.join(", "))
544 }
545 VmValue::TaskHandle(id) => format!("<task:{id}>"),
546 VmValue::Channel(ch) => format!("<channel:{}>", ch.name),
547 VmValue::Atomic(a) => format!("<atomic:{}>", a.value.load(Ordering::SeqCst)),
548 VmValue::McpClient(c) => format!("<mcp_client:{}>", c.name),
549 VmValue::Set(items) => {
550 let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
551 format!("set({})", inner.join(", "))
552 }
553 VmValue::Generator(g) => {
554 if g.done.get() {
555 "<generator (done)>".to_string()
556 } else {
557 "<generator>".to_string()
558 }
559 }
560 }
561 }
562
563 pub fn as_dict(&self) -> Option<&BTreeMap<String, VmValue>> {
565 if let VmValue::Dict(d) = self {
566 Some(d)
567 } else {
568 None
569 }
570 }
571
572 pub fn as_int(&self) -> Option<i64> {
573 if let VmValue::Int(n) = self {
574 Some(*n)
575 } else {
576 None
577 }
578 }
579}
580
/// Synchronous native function callable from scripts.
///
/// It receives the arguments as a slice plus a mutable `String` and returns a
/// value or a `VmError`. NOTE(review): the `String` is presumably an output
/// buffer the builtin may append to — confirm. Held in an `Rc`, so
/// single-threaded.
pub type VmBuiltinFn = Rc<dyn Fn(&[VmValue], &mut String) -> Result<VmValue, VmError>>;
583
584pub fn values_equal(a: &VmValue, b: &VmValue) -> bool {
585 match (a, b) {
586 (VmValue::Int(x), VmValue::Int(y)) => x == y,
587 (VmValue::Float(x), VmValue::Float(y)) => x == y,
588 (VmValue::String(x), VmValue::String(y)) => x == y,
589 (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
590 (VmValue::Nil, VmValue::Nil) => true,
591 (VmValue::Int(x), VmValue::Float(y)) => (*x as f64) == *y,
592 (VmValue::Float(x), VmValue::Int(y)) => *x == (*y as f64),
593 (VmValue::TaskHandle(a), VmValue::TaskHandle(b)) => a == b,
594 (VmValue::Channel(_), VmValue::Channel(_)) => false, (VmValue::Atomic(a), VmValue::Atomic(b)) => {
596 a.value.load(Ordering::SeqCst) == b.value.load(Ordering::SeqCst)
597 }
598 (VmValue::List(a), VmValue::List(b)) => {
599 a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y))
600 }
601 (VmValue::Dict(a), VmValue::Dict(b)) => {
602 a.len() == b.len()
603 && a.iter()
604 .zip(b.iter())
605 .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
606 }
607 (
608 VmValue::EnumVariant {
609 enum_name: a_e,
610 variant: a_v,
611 fields: a_f,
612 },
613 VmValue::EnumVariant {
614 enum_name: b_e,
615 variant: b_v,
616 fields: b_f,
617 },
618 ) => {
619 a_e == b_e
620 && a_v == b_v
621 && a_f.len() == b_f.len()
622 && a_f.iter().zip(b_f.iter()).all(|(x, y)| values_equal(x, y))
623 }
624 (
625 VmValue::StructInstance {
626 struct_name: a_s,
627 fields: a_f,
628 },
629 VmValue::StructInstance {
630 struct_name: b_s,
631 fields: b_f,
632 },
633 ) => {
634 a_s == b_s
635 && a_f.len() == b_f.len()
636 && a_f
637 .iter()
638 .zip(b_f.iter())
639 .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
640 }
641 (VmValue::Set(a), VmValue::Set(b)) => {
642 a.len() == b.len() && a.iter().all(|x| b.iter().any(|y| values_equal(x, y)))
643 }
644 (VmValue::Generator(_), VmValue::Generator(_)) => false, _ => false,
646 }
647}
648
649pub fn compare_values(a: &VmValue, b: &VmValue) -> i32 {
650 match (a, b) {
651 (VmValue::Int(x), VmValue::Int(y)) => x.cmp(y) as i32,
652 (VmValue::Float(x), VmValue::Float(y)) => {
653 if x < y {
654 -1
655 } else if x > y {
656 1
657 } else {
658 0
659 }
660 }
661 (VmValue::Int(x), VmValue::Float(y)) => {
662 let x = *x as f64;
663 if x < *y {
664 -1
665 } else if x > *y {
666 1
667 } else {
668 0
669 }
670 }
671 (VmValue::Float(x), VmValue::Int(y)) => {
672 let y = *y as f64;
673 if *x < y {
674 -1
675 } else if *x > y {
676 1
677 } else {
678 0
679 }
680 }
681 (VmValue::String(x), VmValue::String(y)) => x.cmp(y) as i32,
682 _ => 0,
683 }
684}