1use std::collections::BTreeMap;
2use std::rc::Rc;
3use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
4use std::sync::Arc;
5
6use crate::chunk::CompiledFunction;
7use crate::mcp::VmMcpClientHandle;
8
9pub type VmAsyncBuiltinFn = Rc<
11 dyn Fn(
12 Vec<VmValue>,
13 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<VmValue, VmError>>>>,
14>;
15
16pub type VmJoinHandle = tokio::task::JoinHandle<Result<(VmValue, String), VmError>>;
18
19pub struct VmTaskHandle {
21 pub handle: VmJoinHandle,
22 pub cancel_token: Arc<AtomicBool>,
24}
25
26#[derive(Debug, Clone)]
28pub struct VmChannelHandle {
29 pub name: String,
30 pub sender: Arc<tokio::sync::mpsc::Sender<VmValue>>,
31 pub receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
32 pub closed: Arc<AtomicBool>,
33}
34
/// A shared 64-bit integer cell; every clone refers to the same `AtomicI64`.
#[derive(Clone, Debug)]
pub struct VmAtomicHandle {
    /// The shared integer, reference-counted so clones observe each other's writes.
    pub value: Arc<AtomicI64>,
}
40
41#[derive(Debug, Clone)]
44pub struct VmGenerator {
45 pub done: Rc<std::cell::Cell<bool>>,
47 pub receiver: Rc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
51}
52
53#[derive(Debug, Clone)]
55pub enum VmValue {
56 Int(i64),
57 Float(f64),
58 String(Rc<str>),
59 Bool(bool),
60 Nil,
61 List(Rc<Vec<VmValue>>),
62 Dict(Rc<BTreeMap<String, VmValue>>),
63 Closure(Rc<VmClosure>),
64 Duration(u64),
65 EnumVariant {
66 enum_name: String,
67 variant: String,
68 fields: Vec<VmValue>,
69 },
70 StructInstance {
71 struct_name: String,
72 fields: BTreeMap<String, VmValue>,
73 },
74 TaskHandle(String),
75 Channel(VmChannelHandle),
76 Atomic(VmAtomicHandle),
77 McpClient(VmMcpClientHandle),
78 Set(Rc<Vec<VmValue>>),
79 Generator(VmGenerator),
80}
81
82#[derive(Debug, Clone)]
84pub struct VmClosure {
85 pub func: CompiledFunction,
86 pub env: VmEnv,
87 pub source_dir: Option<std::path::PathBuf>,
91}
92
93#[derive(Debug, Clone)]
95pub struct VmEnv {
96 pub(crate) scopes: Vec<Scope>,
97}
98
99#[derive(Debug, Clone)]
100pub(crate) struct Scope {
101 pub(crate) vars: BTreeMap<String, (VmValue, bool)>, }
103
104impl Default for VmEnv {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110impl VmEnv {
111 pub fn new() -> Self {
112 Self {
113 scopes: vec![Scope {
114 vars: BTreeMap::new(),
115 }],
116 }
117 }
118
119 pub fn push_scope(&mut self) {
120 self.scopes.push(Scope {
121 vars: BTreeMap::new(),
122 });
123 }
124
125 pub fn get(&self, name: &str) -> Option<VmValue> {
126 for scope in self.scopes.iter().rev() {
127 if let Some((val, _)) = scope.vars.get(name) {
128 return Some(val.clone());
129 }
130 }
131 None
132 }
133
134 pub fn define(&mut self, name: &str, value: VmValue, mutable: bool) -> Result<(), VmError> {
135 if let Some(scope) = self.scopes.last_mut() {
136 if let Some((_, existing_mutable)) = scope.vars.get(name) {
137 if !existing_mutable && !mutable {
138 return Err(VmError::Runtime(format!(
139 "Cannot redeclare immutable variable '{name}' in the same scope (use 'var' for mutable bindings)"
140 )));
141 }
142 }
143 scope.vars.insert(name.to_string(), (value, mutable));
144 }
145 Ok(())
146 }
147
148 pub fn all_variables(&self) -> BTreeMap<String, VmValue> {
149 let mut vars = BTreeMap::new();
150 for scope in &self.scopes {
151 for (name, (value, _)) in &scope.vars {
152 vars.insert(name.clone(), value.clone());
153 }
154 }
155 vars
156 }
157
158 pub fn assign(&mut self, name: &str, value: VmValue) -> Result<(), VmError> {
159 for scope in self.scopes.iter_mut().rev() {
160 if let Some((_, mutable)) = scope.vars.get(name) {
161 if !mutable {
162 return Err(VmError::ImmutableAssignment(name.to_string()));
163 }
164 scope.vars.insert(name.to_string(), (value, true));
165 return Ok(());
166 }
167 }
168 Err(VmError::UndefinedVariable(name.to_string()))
169 }
170}
171
/// Levenshtein edit distance between `a` and `b`, counted in `char`s.
///
/// Uses the two-row dynamic-programming formulation, so the extra memory is
/// proportional to the length of `b`.
fn levenshtein(a: &str, b: &str) -> usize {
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    let mut curr = vec![0; b.len() + 1];
    for (i, &ca) in a.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &cb) in b.iter().enumerate() {
            let cost = usize::from(ca != cb);
            curr[j + 1] = (prev[j + 1] + 1).min(curr[j] + 1).min(prev[j] + cost);
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[b.len()]
}

/// Suggests the candidate closest to `name`, for "did you mean …?" hints.
///
/// Candidates are skipped if they are equal to `name` or start with `"__"`.
/// The allowed edit distance grows with the length of `name`:
/// - at most 1 for up to 2 chars,
/// - at most 2 for 3 to 5 chars,
/// - at most 3 otherwise.
///
/// Ties are broken first by the closest length to `name`, then lexicographically.
/// Returns `None` when no candidate is close enough.
pub fn closest_match<'a>(name: &str, candidates: impl Iterator<Item = &'a str>) -> Option<String> {
    // Distances are measured in chars, so the threshold and the length tie-break
    // must be too; byte lengths would loosen the threshold for non-ASCII names.
    let name_len = name.chars().count();
    let max_dist = match name_len {
        0..=2 => 1,
        3..=5 => 2,
        _ => 3,
    };
    let len_gap = |s: &str| (s.chars().count() as isize - name_len as isize).unsigned_abs();
    candidates
        .filter(|c| *c != name && !c.starts_with("__"))
        .map(|c| (c, levenshtein(name, c)))
        .filter(|&(_, d)| d <= max_dist)
        .min_by(|(a, da), (b, db)| {
            da.cmp(db)
                .then_with(|| len_gap(*a).cmp(&len_gap(*b)))
                .then_with(|| a.cmp(b))
        })
        .map(|(c, _)| c.to_string())
}
215
216#[derive(Debug, Clone)]
217pub enum VmError {
218 StackUnderflow,
219 StackOverflow,
220 UndefinedVariable(String),
221 UndefinedBuiltin(String),
222 ImmutableAssignment(String),
223 TypeError(String),
224 Runtime(String),
225 DivisionByZero,
226 Thrown(VmValue),
227 CategorizedError {
229 message: String,
230 category: ErrorCategory,
231 },
232 Return(VmValue),
233 InvalidInstruction(u8),
234}
235
/// Coarse classification of VM errors.
///
/// Each category has a stable snake_case name, given by [`ErrorCategory::as_str`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ErrorCategory {
    /// The operation ran out of time.
    Timeout,
    /// Authentication or authorization failed.
    Auth,
    /// A rate limit, quota, or overload condition.
    RateLimit,
    /// A tool invocation failed.
    ToolError,
    /// A tool invocation was rejected.
    ToolRejected,
    /// The operation was cancelled.
    Cancelled,
    /// The requested resource does not exist.
    NotFound,
    /// A circuit breaker is open.
    CircuitOpen,
    /// Anything without a more specific category.
    Generic,
}

impl ErrorCategory {
    /// Returns the category's canonical snake_case name.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Timeout => "timeout",
            Self::Auth => "auth",
            Self::RateLimit => "rate_limit",
            Self::ToolError => "tool_error",
            Self::ToolRejected => "tool_rejected",
            Self::Cancelled => "cancelled",
            Self::NotFound => "not_found",
            Self::CircuitOpen => "circuit_open",
            Self::Generic => "generic",
        }
    }

    /// Parses a canonical name as produced by [`ErrorCategory::as_str`].
    ///
    /// The match is exact and case-sensitive. Any other string, including
    /// `"generic"`, yields [`ErrorCategory::Generic`].
    pub fn parse(s: &str) -> Self {
        // Matching against `as_str` keeps parsing and naming in lockstep.
        [
            Self::Timeout,
            Self::Auth,
            Self::RateLimit,
            Self::ToolError,
            Self::ToolRejected,
            Self::Cancelled,
            Self::NotFound,
            Self::CircuitOpen,
        ]
        .iter()
        .find(|candidate| candidate.as_str() == s)
        .cloned()
        .unwrap_or(Self::Generic)
    }
}
288
289pub fn categorized_error(message: impl Into<String>, category: ErrorCategory) -> VmError {
291 VmError::CategorizedError {
292 message: message.into(),
293 category,
294 }
295}
296
297pub fn error_to_category(err: &VmError) -> ErrorCategory {
306 match err {
307 VmError::CategorizedError { category, .. } => category.clone(),
308 VmError::Thrown(VmValue::Dict(d)) => d
309 .get("category")
310 .map(|v| ErrorCategory::parse(&v.display()))
311 .unwrap_or(ErrorCategory::Generic),
312 VmError::Thrown(VmValue::String(s)) => classify_error_message(s),
313 VmError::Runtime(msg) => classify_error_message(msg),
314 _ => ErrorCategory::Generic,
315 }
316}
317
318fn classify_error_message(msg: &str) -> ErrorCategory {
321 if let Some(cat) = classify_by_http_status(msg) {
323 return cat;
324 }
325 if msg.contains("Deadline exceeded") || msg.contains("context deadline exceeded") {
328 return ErrorCategory::Timeout;
329 }
330 if msg.contains("overloaded_error") || msg.contains("api_error") {
331 return ErrorCategory::RateLimit;
333 }
334 if msg.contains("insufficient_quota") || msg.contains("billing_hard_limit_reached") {
335 return ErrorCategory::RateLimit;
337 }
338 if msg.contains("invalid_api_key") || msg.contains("authentication_error") {
339 return ErrorCategory::Auth;
340 }
341 if msg.contains("not_found_error") || msg.contains("model_not_found") {
342 return ErrorCategory::NotFound;
343 }
344 if msg.contains("circuit_open") {
345 return ErrorCategory::CircuitOpen;
346 }
347 ErrorCategory::Generic
348}
349
350fn classify_by_http_status(msg: &str) -> Option<ErrorCategory> {
354 for code in extract_http_status_codes(msg) {
357 return Some(match code {
358 401 | 403 => ErrorCategory::Auth,
359 404 | 410 => ErrorCategory::NotFound,
360 408 | 504 | 522 | 524 => ErrorCategory::Timeout,
361 429 | 503 => ErrorCategory::RateLimit,
362 _ => continue,
363 });
364 }
365 None
366}
367
/// Collects every standalone three-digit number in `msg` that lies in 400–599.
///
/// "Standalone" means the three digits are not preceded or followed by
/// another ASCII digit, so `4040` and `12345` yield nothing. Results are
/// returned in order of appearance.
fn extract_http_status_codes(msg: &str) -> Vec<u16> {
    let bytes = msg.as_bytes();
    let digit_at = |idx: usize| matches!(bytes.get(idx), Some(b) if b.is_ascii_digit());
    (0..bytes.len().saturating_sub(2))
        .filter(|&start| {
            let three_digits = digit_at(start) && digit_at(start + 1) && digit_at(start + 2);
            let isolated_left = start == 0 || !digit_at(start - 1);
            let isolated_right = !digit_at(start + 3);
            three_digits && isolated_left && isolated_right
        })
        .map(|start| {
            bytes[start..start + 3]
                .iter()
                .fold(0u16, |acc, &digit| acc * 10 + u16::from(digit - b'0'))
        })
        .filter(|code| (400..=599).contains(code))
        .collect()
}
391}
392
393impl std::fmt::Display for VmError {
394 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395 match self {
396 VmError::StackUnderflow => write!(f, "Stack underflow"),
397 VmError::StackOverflow => write!(f, "Stack overflow: too many nested calls"),
398 VmError::UndefinedVariable(n) => write!(f, "Undefined variable: {n}"),
399 VmError::UndefinedBuiltin(n) => write!(f, "Undefined builtin: {n}"),
400 VmError::ImmutableAssignment(n) => {
401 write!(f, "Cannot assign to immutable binding: {n}")
402 }
403 VmError::TypeError(msg) => write!(f, "Type error: {msg}"),
404 VmError::Runtime(msg) => write!(f, "Runtime error: {msg}"),
405 VmError::DivisionByZero => write!(f, "Division by zero"),
406 VmError::Thrown(v) => write!(f, "Thrown: {}", v.display()),
407 VmError::CategorizedError { message, category } => {
408 write!(f, "Error [{}]: {}", category.as_str(), message)
409 }
410 VmError::Return(_) => write!(f, "Return from function"),
411 VmError::InvalidInstruction(op) => write!(f, "Invalid instruction: 0x{op:02x}"),
412 }
413 }
414}
415
416impl std::error::Error for VmError {}
417
418impl VmValue {
419 pub fn is_truthy(&self) -> bool {
420 match self {
421 VmValue::Bool(b) => *b,
422 VmValue::Nil => false,
423 VmValue::Int(n) => *n != 0,
424 VmValue::Float(n) => *n != 0.0,
425 VmValue::String(s) => !s.is_empty(),
426 VmValue::List(l) => !l.is_empty(),
427 VmValue::Dict(d) => !d.is_empty(),
428 VmValue::Closure(_) => true,
429 VmValue::Duration(ms) => *ms > 0,
430 VmValue::EnumVariant { .. } => true,
431 VmValue::StructInstance { .. } => true,
432 VmValue::TaskHandle(_) => true,
433 VmValue::Channel(_) => true,
434 VmValue::Atomic(_) => true,
435 VmValue::McpClient(_) => true,
436 VmValue::Set(s) => !s.is_empty(),
437 VmValue::Generator(_) => true,
438 }
439 }
440
441 pub fn type_name(&self) -> &'static str {
442 match self {
443 VmValue::String(_) => "string",
444 VmValue::Int(_) => "int",
445 VmValue::Float(_) => "float",
446 VmValue::Bool(_) => "bool",
447 VmValue::Nil => "nil",
448 VmValue::List(_) => "list",
449 VmValue::Dict(_) => "dict",
450 VmValue::Closure(_) => "closure",
451 VmValue::Duration(_) => "duration",
452 VmValue::EnumVariant { .. } => "enum",
453 VmValue::StructInstance { .. } => "struct",
454 VmValue::TaskHandle(_) => "task_handle",
455 VmValue::Channel(_) => "channel",
456 VmValue::Atomic(_) => "atomic",
457 VmValue::McpClient(_) => "mcp_client",
458 VmValue::Set(_) => "set",
459 VmValue::Generator(_) => "generator",
460 }
461 }
462
463 pub fn display(&self) -> String {
464 match self {
465 VmValue::Int(n) => n.to_string(),
466 VmValue::Float(n) => {
467 if *n == (*n as i64) as f64 && n.abs() < 1e15 {
468 format!("{:.1}", n)
469 } else {
470 n.to_string()
471 }
472 }
473 VmValue::String(s) => s.to_string(),
474 VmValue::Bool(b) => (if *b { "true" } else { "false" }).to_string(),
475 VmValue::Nil => "nil".to_string(),
476 VmValue::List(items) => {
477 let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
478 format!("[{}]", inner.join(", "))
479 }
480 VmValue::Dict(map) => {
481 let inner: Vec<String> = map
482 .iter()
483 .map(|(k, v)| format!("{k}: {}", v.display()))
484 .collect();
485 format!("{{{}}}", inner.join(", "))
486 }
487 VmValue::Closure(c) => format!("<fn({})>", c.func.params.join(", ")),
488 VmValue::Duration(ms) => {
489 if *ms >= 3_600_000 && ms % 3_600_000 == 0 {
490 format!("{}h", ms / 3_600_000)
491 } else if *ms >= 60_000 && ms % 60_000 == 0 {
492 format!("{}m", ms / 60_000)
493 } else if *ms >= 1000 && ms % 1000 == 0 {
494 format!("{}s", ms / 1000)
495 } else {
496 format!("{}ms", ms)
497 }
498 }
499 VmValue::EnumVariant {
500 enum_name,
501 variant,
502 fields,
503 } => {
504 if fields.is_empty() {
505 format!("{enum_name}.{variant}")
506 } else {
507 let inner: Vec<String> = fields.iter().map(|v| v.display()).collect();
508 format!("{enum_name}.{variant}({})", inner.join(", "))
509 }
510 }
511 VmValue::StructInstance {
512 struct_name,
513 fields,
514 } => {
515 let inner: Vec<String> = fields
516 .iter()
517 .map(|(k, v)| format!("{k}: {}", v.display()))
518 .collect();
519 format!("{struct_name} {{{}}}", inner.join(", "))
520 }
521 VmValue::TaskHandle(id) => format!("<task:{id}>"),
522 VmValue::Channel(ch) => format!("<channel:{}>", ch.name),
523 VmValue::Atomic(a) => format!("<atomic:{}>", a.value.load(Ordering::SeqCst)),
524 VmValue::McpClient(c) => format!("<mcp_client:{}>", c.name),
525 VmValue::Set(items) => {
526 let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
527 format!("set({})", inner.join(", "))
528 }
529 VmValue::Generator(g) => {
530 if g.done.get() {
531 "<generator (done)>".to_string()
532 } else {
533 "<generator>".to_string()
534 }
535 }
536 }
537 }
538
539 pub fn as_dict(&self) -> Option<&BTreeMap<String, VmValue>> {
541 if let VmValue::Dict(d) = self {
542 Some(d)
543 } else {
544 None
545 }
546 }
547
548 pub fn as_int(&self) -> Option<i64> {
549 if let VmValue::Int(n) = self {
550 Some(*n)
551 } else {
552 None
553 }
554 }
555}
556
557pub type VmBuiltinFn = Rc<dyn Fn(&[VmValue], &mut String) -> Result<VmValue, VmError>>;
559
560pub fn values_equal(a: &VmValue, b: &VmValue) -> bool {
561 match (a, b) {
562 (VmValue::Int(x), VmValue::Int(y)) => x == y,
563 (VmValue::Float(x), VmValue::Float(y)) => x == y,
564 (VmValue::String(x), VmValue::String(y)) => x == y,
565 (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
566 (VmValue::Nil, VmValue::Nil) => true,
567 (VmValue::Int(x), VmValue::Float(y)) => (*x as f64) == *y,
568 (VmValue::Float(x), VmValue::Int(y)) => *x == (*y as f64),
569 (VmValue::TaskHandle(a), VmValue::TaskHandle(b)) => a == b,
570 (VmValue::Channel(_), VmValue::Channel(_)) => false, (VmValue::Atomic(a), VmValue::Atomic(b)) => {
572 a.value.load(Ordering::SeqCst) == b.value.load(Ordering::SeqCst)
573 }
574 (VmValue::List(a), VmValue::List(b)) => {
575 a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y))
576 }
577 (VmValue::Dict(a), VmValue::Dict(b)) => {
578 a.len() == b.len()
579 && a.iter()
580 .zip(b.iter())
581 .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
582 }
583 (
584 VmValue::EnumVariant {
585 enum_name: a_e,
586 variant: a_v,
587 fields: a_f,
588 },
589 VmValue::EnumVariant {
590 enum_name: b_e,
591 variant: b_v,
592 fields: b_f,
593 },
594 ) => {
595 a_e == b_e
596 && a_v == b_v
597 && a_f.len() == b_f.len()
598 && a_f.iter().zip(b_f.iter()).all(|(x, y)| values_equal(x, y))
599 }
600 (
601 VmValue::StructInstance {
602 struct_name: a_s,
603 fields: a_f,
604 },
605 VmValue::StructInstance {
606 struct_name: b_s,
607 fields: b_f,
608 },
609 ) => {
610 a_s == b_s
611 && a_f.len() == b_f.len()
612 && a_f
613 .iter()
614 .zip(b_f.iter())
615 .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
616 }
617 (VmValue::Set(a), VmValue::Set(b)) => {
618 a.len() == b.len() && a.iter().all(|x| b.iter().any(|y| values_equal(x, y)))
619 }
620 (VmValue::Generator(_), VmValue::Generator(_)) => false, _ => false,
622 }
623}
624
625pub fn compare_values(a: &VmValue, b: &VmValue) -> i32 {
626 match (a, b) {
627 (VmValue::Int(x), VmValue::Int(y)) => x.cmp(y) as i32,
628 (VmValue::Float(x), VmValue::Float(y)) => {
629 if x < y {
630 -1
631 } else if x > y {
632 1
633 } else {
634 0
635 }
636 }
637 (VmValue::Int(x), VmValue::Float(y)) => {
638 let x = *x as f64;
639 if x < *y {
640 -1
641 } else if x > *y {
642 1
643 } else {
644 0
645 }
646 }
647 (VmValue::Float(x), VmValue::Int(y)) => {
648 let y = *y as f64;
649 if *x < y {
650 -1
651 } else if *x > y {
652 1
653 } else {
654 0
655 }
656 }
657 (VmValue::String(x), VmValue::String(y)) => x.cmp(y) as i32,
658 _ => 0,
659 }
660}