1use std::collections::BTreeMap;
2use std::rc::Rc;
3use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
4use std::sync::Arc;
5
6use crate::chunk::CompiledFunction;
7use crate::mcp::VmMcpClientHandle;
8
/// Signature of an asynchronous native builtin.
///
/// The function receives the call's arguments by value and returns a pinned,
/// boxed future that resolves to the builtin's result or a [`VmError`].
/// No `Send` bound appears on the future, and the function itself is shared
/// through `Rc`, so these builtins are single-threaded values.
pub type VmAsyncBuiltinFn = Rc<
    dyn Fn(
        Vec<VmValue>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<VmValue, VmError>>>>,
>;
15
/// Tokio join handle for a spawned VM task.
///
/// The task resolves to either a [`VmError`] or a `(VmValue, String)` pair.
/// NOTE(review): the `String` is presumably output captured by the task —
/// confirm at the spawn site.
pub type VmJoinHandle = tokio::task::JoinHandle<Result<(VmValue, String), VmError>>;
18
/// A spawned VM task paired with a shared cancellation flag.
pub struct VmTaskHandle {
    /// Handle used to await the task's result.
    pub handle: VmJoinHandle,
    /// Shared boolean flag. NOTE(review): presumably set to request cooperative
    /// cancellation and polled by the running task — confirm.
    pub cancel_token: Arc<AtomicBool>,
}
25
/// A named MPSC channel exposed to scripts.
///
/// Every field is behind an `Arc`, so cloning the handle yields another handle
/// to the same underlying channel.
#[derive(Debug, Clone)]
pub struct VmChannelHandle {
    /// Name shown when the value is displayed (`<channel:NAME>`).
    pub name: String,
    /// Sending half of the channel.
    pub sender: Arc<tokio::sync::mpsc::Sender<VmValue>>,
    /// The single receiving half, behind an async mutex so clones can share it.
    pub receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
    /// NOTE(review): presumably set once the channel has been closed — confirm
    /// against the builtin that closes channels.
    pub closed: Arc<AtomicBool>,
}
34
/// A shared 64-bit atomic integer; clones share the same `AtomicI64` via `Arc`.
#[derive(Debug, Clone)]
pub struct VmAtomicHandle {
    pub value: Arc<AtomicI64>,
}
40
/// A script-level generator whose values arrive over an mpsc channel.
///
/// Both fields are `Rc`-shared, so clones observe the same state; the type is
/// therefore not `Send`.
#[derive(Debug, Clone)]
pub struct VmGenerator {
    /// Completion flag; when set, the value displays as `<generator (done)>`.
    pub done: Rc<std::cell::Cell<bool>>,
    /// Receiving end for the produced values (presumably the generator's
    /// yields — confirm), behind an async mutex so clones can share it.
    pub receiver: Rc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
}
52
/// A runtime value manipulated by the VM.
///
/// `String`, `List`, `Dict`, `Closure` and `Set` hold their payload behind
/// `Rc`, so cloning a `VmValue` shares that data instead of copying it.
#[derive(Debug, Clone)]
pub enum VmValue {
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// Immutable string.
    String(Rc<str>),
    Bool(bool),
    /// Absence of a value; always falsy.
    Nil,
    List(Rc<Vec<VmValue>>),
    /// String-keyed map. `BTreeMap` keeps keys sorted, so iteration (and thus
    /// display and equality) is deterministic.
    Dict(Rc<BTreeMap<String, VmValue>>),
    Closure(Rc<VmClosure>),
    /// Duration in milliseconds (displayed with `ms`/`s`/`m`/`h` suffixes).
    Duration(u64),
    /// A value of a user-declared enum, with positional payload fields.
    EnumVariant {
        enum_name: String,
        variant: String,
        fields: Vec<VmValue>,
    },
    /// A value of a user-declared struct, with named fields.
    StructInstance {
        struct_name: String,
        fields: BTreeMap<String, VmValue>,
    },
    /// Identifier of a spawned task (displayed as `<task:ID>`).
    TaskHandle(String),
    Channel(VmChannelHandle),
    Atomic(VmAtomicHandle),
    McpClient(VmMcpClientHandle),
    /// Set stored as a `Vec`; equality ignores element order.
    /// NOTE(review): uniqueness is presumably enforced where sets are built — confirm.
    Set(Rc<Vec<VmValue>>),
    Generator(VmGenerator),
}
81
/// A function value: compiled code paired with an environment
/// (presumably the one captured when the closure was created — confirm).
#[derive(Debug, Clone)]
pub struct VmClosure {
    pub func: CompiledFunction,
    pub env: VmEnv,
}
88
/// A stack of lexical scopes. The last element of `scopes` is the innermost:
/// new bindings go there, and lookups search from it outward.
#[derive(Debug, Clone)]
pub struct VmEnv {
    pub(crate) scopes: Vec<Scope>,
}
94
/// A single lexical scope mapping variable names to their bindings.
#[derive(Debug, Clone)]
pub(crate) struct Scope {
    /// name -> (value, is_mutable). `VmEnv::assign` refuses to overwrite a
    /// binding whose flag is `false`.
    pub(crate) vars: BTreeMap<String, (VmValue, bool)>,
}
99
100impl Default for VmEnv {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
106impl VmEnv {
107 pub fn new() -> Self {
108 Self {
109 scopes: vec![Scope {
110 vars: BTreeMap::new(),
111 }],
112 }
113 }
114
115 pub fn push_scope(&mut self) {
116 self.scopes.push(Scope {
117 vars: BTreeMap::new(),
118 });
119 }
120
121 pub fn get(&self, name: &str) -> Option<VmValue> {
122 for scope in self.scopes.iter().rev() {
123 if let Some((val, _)) = scope.vars.get(name) {
124 return Some(val.clone());
125 }
126 }
127 None
128 }
129
130 pub fn define(&mut self, name: &str, value: VmValue, mutable: bool) -> Result<(), VmError> {
131 if let Some(scope) = self.scopes.last_mut() {
132 if let Some((_, existing_mutable)) = scope.vars.get(name) {
133 if !existing_mutable && !mutable {
134 return Err(VmError::Runtime(format!(
135 "Cannot redeclare immutable variable '{name}' in the same scope (use 'var' for mutable bindings)"
136 )));
137 }
138 }
139 scope.vars.insert(name.to_string(), (value, mutable));
140 }
141 Ok(())
142 }
143
144 pub fn all_variables(&self) -> BTreeMap<String, VmValue> {
145 let mut vars = BTreeMap::new();
146 for scope in &self.scopes {
147 for (name, (value, _)) in &scope.vars {
148 vars.insert(name.clone(), value.clone());
149 }
150 }
151 vars
152 }
153
154 pub fn assign(&mut self, name: &str, value: VmValue) -> Result<(), VmError> {
155 for scope in self.scopes.iter_mut().rev() {
156 if let Some((_, mutable)) = scope.vars.get(name) {
157 if !mutable {
158 return Err(VmError::ImmutableAssignment(name.to_string()));
159 }
160 scope.vars.insert(name.to_string(), (value, true));
161 return Ok(());
162 }
163 }
164 Err(VmError::UndefinedVariable(name.to_string()))
165 }
166}
167
/// Edit distance (insertions, deletions, substitutions) between `a` and `b`,
/// measured in `char`s rather than bytes.
///
/// Uses two rolling rows of the classic dynamic-programming table, so memory
/// is O(len(b)).
fn levenshtein(a: &str, b: &str) -> usize {
    let lhs: Vec<char> = a.chars().collect();
    let rhs: Vec<char> = b.chars().collect();

    // `row[j]` = distance between the processed prefix of `lhs` and `rhs[..j]`.
    let mut row: Vec<usize> = (0..=rhs.len()).collect();
    let mut next = vec![0usize; rhs.len() + 1];

    for (i, &lc) in lhs.iter().enumerate() {
        next[0] = i + 1;
        for (j, &rc) in rhs.iter().enumerate() {
            let substitution = row[j] + usize::from(lc != rc);
            let deletion = row[j + 1] + 1;
            let insertion = next[j] + 1;
            next[j + 1] = substitution.min(deletion).min(insertion);
        }
        std::mem::swap(&mut row, &mut next);
    }

    row[rhs.len()]
}
186
187pub fn closest_match<'a>(name: &str, candidates: impl Iterator<Item = &'a str>) -> Option<String> {
190 let max_dist = match name.len() {
191 0..=2 => 1,
192 3..=5 => 2,
193 _ => 3,
194 };
195 candidates
196 .filter(|c| *c != name && !c.starts_with("__"))
197 .map(|c| (c, levenshtein(name, c)))
198 .filter(|(_, d)| *d <= max_dist)
199 .min_by(|(a, da), (b, db)| {
201 da.cmp(db)
202 .then_with(|| {
203 let a_diff = (a.len() as isize - name.len() as isize).unsigned_abs();
204 let b_diff = (b.len() as isize - name.len() as isize).unsigned_abs();
205 a_diff.cmp(&b_diff)
206 })
207 .then_with(|| a.cmp(b))
208 })
209 .map(|(c, _)| c.to_string())
210}
211
/// Errors, plus one non-local control-flow signal, raised while executing VM code.
#[derive(Debug, Clone)]
pub enum VmError {
    StackUnderflow,
    /// Call nesting exceeded the VM's limit.
    StackOverflow,
    UndefinedVariable(String),
    UndefinedBuiltin(String),
    /// Write to an immutable binding (produced by `VmEnv::assign`).
    ImmutableAssignment(String),
    TypeError(String),
    Runtime(String),
    DivisionByZero,
    /// A value thrown by script code. `error_to_category` reads a dict's
    /// `category` key or classifies a string's text.
    Thrown(VmValue),
    /// An error tagged with an explicit [`ErrorCategory`].
    CategorizedError {
        message: String,
        category: ErrorCategory,
    },
    /// NOTE(review): appears to carry a function's return value through the
    /// error path (displayed as "Return from function") — confirm in the interpreter loop.
    Return(VmValue),
    /// Unknown opcode byte (displayed in hex).
    InvalidInstruction(u8),
}
231
/// Coarse classification of errors. The string form of each variant
/// (see `as_str` / `parse`) is shown in parentheses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCategory {
    /// `"timeout"`
    Timeout,
    /// `"auth"`
    Auth,
    /// `"rate_limit"`
    RateLimit,
    /// `"tool_error"`
    ToolError,
    /// `"tool_rejected"`
    ToolRejected,
    /// `"cancelled"`
    Cancelled,
    /// `"not_found"`
    NotFound,
    /// `"circuit_open"`
    CircuitOpen,
    /// `"generic"`; also the fallback for unrecognised input.
    Generic,
}
254
255impl ErrorCategory {
256 pub fn as_str(&self) -> &'static str {
257 match self {
258 ErrorCategory::Timeout => "timeout",
259 ErrorCategory::Auth => "auth",
260 ErrorCategory::RateLimit => "rate_limit",
261 ErrorCategory::ToolError => "tool_error",
262 ErrorCategory::ToolRejected => "tool_rejected",
263 ErrorCategory::Cancelled => "cancelled",
264 ErrorCategory::NotFound => "not_found",
265 ErrorCategory::CircuitOpen => "circuit_open",
266 ErrorCategory::Generic => "generic",
267 }
268 }
269
270 pub fn parse(s: &str) -> Self {
271 match s {
272 "timeout" => ErrorCategory::Timeout,
273 "auth" => ErrorCategory::Auth,
274 "rate_limit" => ErrorCategory::RateLimit,
275 "tool_error" => ErrorCategory::ToolError,
276 "tool_rejected" => ErrorCategory::ToolRejected,
277 "cancelled" => ErrorCategory::Cancelled,
278 "not_found" => ErrorCategory::NotFound,
279 "circuit_open" => ErrorCategory::CircuitOpen,
280 _ => ErrorCategory::Generic,
281 }
282 }
283}
284
285pub fn categorized_error(message: impl Into<String>, category: ErrorCategory) -> VmError {
287 VmError::CategorizedError {
288 message: message.into(),
289 category,
290 }
291}
292
293pub fn error_to_category(err: &VmError) -> ErrorCategory {
302 match err {
303 VmError::CategorizedError { category, .. } => category.clone(),
304 VmError::Thrown(VmValue::Dict(d)) => d
305 .get("category")
306 .map(|v| ErrorCategory::parse(&v.display()))
307 .unwrap_or(ErrorCategory::Generic),
308 VmError::Thrown(VmValue::String(s)) => classify_error_message(s),
309 VmError::Runtime(msg) => classify_error_message(msg),
310 _ => ErrorCategory::Generic,
311 }
312}
313
314fn classify_error_message(msg: &str) -> ErrorCategory {
317 if let Some(cat) = classify_by_http_status(msg) {
319 return cat;
320 }
321 if msg.contains("Deadline exceeded") || msg.contains("context deadline exceeded") {
324 return ErrorCategory::Timeout;
325 }
326 if msg.contains("overloaded_error") || msg.contains("api_error") {
327 return ErrorCategory::RateLimit;
329 }
330 if msg.contains("insufficient_quota") || msg.contains("billing_hard_limit_reached") {
331 return ErrorCategory::RateLimit;
333 }
334 if msg.contains("invalid_api_key") || msg.contains("authentication_error") {
335 return ErrorCategory::Auth;
336 }
337 if msg.contains("not_found_error") || msg.contains("model_not_found") {
338 return ErrorCategory::NotFound;
339 }
340 if msg.contains("circuit_open") {
341 return ErrorCategory::CircuitOpen;
342 }
343 ErrorCategory::Generic
344}
345
346fn classify_by_http_status(msg: &str) -> Option<ErrorCategory> {
350 for code in extract_http_status_codes(msg) {
353 return Some(match code {
354 401 | 403 => ErrorCategory::Auth,
355 404 | 410 => ErrorCategory::NotFound,
356 408 | 504 | 522 | 524 => ErrorCategory::Timeout,
357 429 | 503 => ErrorCategory::RateLimit,
358 _ => continue,
359 });
360 }
361 None
362}
363
/// Returns, in order of appearance, every standalone three-digit number in
/// `msg` within the HTTP error range 400–599.
///
/// "Standalone" means the digits are not part of a longer digit run, so
/// `1404` and `4040` yield nothing.
fn extract_http_status_codes(msg: &str) -> Vec<u16> {
    let bytes = msg.as_bytes();
    // Whether an ASCII digit sits at `idx`; out-of-range indices count as "no".
    let digit_at = |idx: usize| matches!(bytes.get(idx), Some(b) if b.is_ascii_digit());

    bytes
        .windows(3)
        .enumerate()
        .filter(|&(start, window)| {
            window.iter().all(u8::is_ascii_digit)
                && (start == 0 || !digit_at(start - 1))
                && !digit_at(start + 3)
        })
        .map(|(_, window)| {
            window
                .iter()
                .fold(0u16, |acc, digit| acc * 10 + u16::from(digit - b'0'))
        })
        .filter(|code| (400..=599).contains(code))
        .collect()
}
388
389impl std::fmt::Display for VmError {
390 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
391 match self {
392 VmError::StackUnderflow => write!(f, "Stack underflow"),
393 VmError::StackOverflow => write!(f, "Stack overflow: too many nested calls"),
394 VmError::UndefinedVariable(n) => write!(f, "Undefined variable: {n}"),
395 VmError::UndefinedBuiltin(n) => write!(f, "Undefined builtin: {n}"),
396 VmError::ImmutableAssignment(n) => {
397 write!(f, "Cannot assign to immutable binding: {n}")
398 }
399 VmError::TypeError(msg) => write!(f, "Type error: {msg}"),
400 VmError::Runtime(msg) => write!(f, "Runtime error: {msg}"),
401 VmError::DivisionByZero => write!(f, "Division by zero"),
402 VmError::Thrown(v) => write!(f, "Thrown: {}", v.display()),
403 VmError::CategorizedError { message, category } => {
404 write!(f, "Error [{}]: {}", category.as_str(), message)
405 }
406 VmError::Return(_) => write!(f, "Return from function"),
407 VmError::InvalidInstruction(op) => write!(f, "Invalid instruction: 0x{op:02x}"),
408 }
409 }
410}
411
// Lets `VmError` be used wherever a `std::error::Error` is expected (e.g. boxed
// as `dyn Error`); the message comes from the `Display` impl, and no `source`
// chain is exposed.
impl std::error::Error for VmError {}
413
414impl VmValue {
415 pub fn is_truthy(&self) -> bool {
416 match self {
417 VmValue::Bool(b) => *b,
418 VmValue::Nil => false,
419 VmValue::Int(n) => *n != 0,
420 VmValue::Float(n) => *n != 0.0,
421 VmValue::String(s) => !s.is_empty(),
422 VmValue::List(l) => !l.is_empty(),
423 VmValue::Dict(d) => !d.is_empty(),
424 VmValue::Closure(_) => true,
425 VmValue::Duration(ms) => *ms > 0,
426 VmValue::EnumVariant { .. } => true,
427 VmValue::StructInstance { .. } => true,
428 VmValue::TaskHandle(_) => true,
429 VmValue::Channel(_) => true,
430 VmValue::Atomic(_) => true,
431 VmValue::McpClient(_) => true,
432 VmValue::Set(s) => !s.is_empty(),
433 VmValue::Generator(_) => true,
434 }
435 }
436
437 pub fn type_name(&self) -> &'static str {
438 match self {
439 VmValue::String(_) => "string",
440 VmValue::Int(_) => "int",
441 VmValue::Float(_) => "float",
442 VmValue::Bool(_) => "bool",
443 VmValue::Nil => "nil",
444 VmValue::List(_) => "list",
445 VmValue::Dict(_) => "dict",
446 VmValue::Closure(_) => "closure",
447 VmValue::Duration(_) => "duration",
448 VmValue::EnumVariant { .. } => "enum",
449 VmValue::StructInstance { .. } => "struct",
450 VmValue::TaskHandle(_) => "task_handle",
451 VmValue::Channel(_) => "channel",
452 VmValue::Atomic(_) => "atomic",
453 VmValue::McpClient(_) => "mcp_client",
454 VmValue::Set(_) => "set",
455 VmValue::Generator(_) => "generator",
456 }
457 }
458
459 pub fn display(&self) -> String {
460 match self {
461 VmValue::Int(n) => n.to_string(),
462 VmValue::Float(n) => {
463 if *n == (*n as i64) as f64 && n.abs() < 1e15 {
464 format!("{:.1}", n)
465 } else {
466 n.to_string()
467 }
468 }
469 VmValue::String(s) => s.to_string(),
470 VmValue::Bool(b) => (if *b { "true" } else { "false" }).to_string(),
471 VmValue::Nil => "nil".to_string(),
472 VmValue::List(items) => {
473 let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
474 format!("[{}]", inner.join(", "))
475 }
476 VmValue::Dict(map) => {
477 let inner: Vec<String> = map
478 .iter()
479 .map(|(k, v)| format!("{k}: {}", v.display()))
480 .collect();
481 format!("{{{}}}", inner.join(", "))
482 }
483 VmValue::Closure(c) => format!("<fn({})>", c.func.params.join(", ")),
484 VmValue::Duration(ms) => {
485 if *ms >= 3_600_000 && ms % 3_600_000 == 0 {
486 format!("{}h", ms / 3_600_000)
487 } else if *ms >= 60_000 && ms % 60_000 == 0 {
488 format!("{}m", ms / 60_000)
489 } else if *ms >= 1000 && ms % 1000 == 0 {
490 format!("{}s", ms / 1000)
491 } else {
492 format!("{}ms", ms)
493 }
494 }
495 VmValue::EnumVariant {
496 enum_name,
497 variant,
498 fields,
499 } => {
500 if fields.is_empty() {
501 format!("{enum_name}.{variant}")
502 } else {
503 let inner: Vec<String> = fields.iter().map(|v| v.display()).collect();
504 format!("{enum_name}.{variant}({})", inner.join(", "))
505 }
506 }
507 VmValue::StructInstance {
508 struct_name,
509 fields,
510 } => {
511 let inner: Vec<String> = fields
512 .iter()
513 .map(|(k, v)| format!("{k}: {}", v.display()))
514 .collect();
515 format!("{struct_name} {{{}}}", inner.join(", "))
516 }
517 VmValue::TaskHandle(id) => format!("<task:{id}>"),
518 VmValue::Channel(ch) => format!("<channel:{}>", ch.name),
519 VmValue::Atomic(a) => format!("<atomic:{}>", a.value.load(Ordering::SeqCst)),
520 VmValue::McpClient(c) => format!("<mcp_client:{}>", c.name),
521 VmValue::Set(items) => {
522 let inner: Vec<String> = items.iter().map(|i| i.display()).collect();
523 format!("set({})", inner.join(", "))
524 }
525 VmValue::Generator(g) => {
526 if g.done.get() {
527 "<generator (done)>".to_string()
528 } else {
529 "<generator>".to_string()
530 }
531 }
532 }
533 }
534
535 pub fn as_dict(&self) -> Option<&BTreeMap<String, VmValue>> {
537 if let VmValue::Dict(d) = self {
538 Some(d)
539 } else {
540 None
541 }
542 }
543
544 pub fn as_int(&self) -> Option<i64> {
545 if let VmValue::Int(n) = self {
546 Some(*n)
547 } else {
548 None
549 }
550 }
551}
552
/// Signature of a synchronous native builtin.
///
/// It receives the call's arguments as a slice plus a mutable `String`, and
/// returns the result value or a [`VmError`].
/// NOTE(review): the `&mut String` looks like an output buffer that builtins
/// append to — confirm against the call site.
pub type VmBuiltinFn = Rc<dyn Fn(&[VmValue], &mut String) -> Result<VmValue, VmError>>;
555
556pub fn values_equal(a: &VmValue, b: &VmValue) -> bool {
557 match (a, b) {
558 (VmValue::Int(x), VmValue::Int(y)) => x == y,
559 (VmValue::Float(x), VmValue::Float(y)) => x == y,
560 (VmValue::String(x), VmValue::String(y)) => x == y,
561 (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
562 (VmValue::Nil, VmValue::Nil) => true,
563 (VmValue::Int(x), VmValue::Float(y)) => (*x as f64) == *y,
564 (VmValue::Float(x), VmValue::Int(y)) => *x == (*y as f64),
565 (VmValue::TaskHandle(a), VmValue::TaskHandle(b)) => a == b,
566 (VmValue::Channel(_), VmValue::Channel(_)) => false, (VmValue::Atomic(a), VmValue::Atomic(b)) => {
568 a.value.load(Ordering::SeqCst) == b.value.load(Ordering::SeqCst)
569 }
570 (VmValue::List(a), VmValue::List(b)) => {
571 a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y))
572 }
573 (VmValue::Dict(a), VmValue::Dict(b)) => {
574 a.len() == b.len()
575 && a.iter()
576 .zip(b.iter())
577 .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
578 }
579 (
580 VmValue::EnumVariant {
581 enum_name: a_e,
582 variant: a_v,
583 fields: a_f,
584 },
585 VmValue::EnumVariant {
586 enum_name: b_e,
587 variant: b_v,
588 fields: b_f,
589 },
590 ) => {
591 a_e == b_e
592 && a_v == b_v
593 && a_f.len() == b_f.len()
594 && a_f.iter().zip(b_f.iter()).all(|(x, y)| values_equal(x, y))
595 }
596 (
597 VmValue::StructInstance {
598 struct_name: a_s,
599 fields: a_f,
600 },
601 VmValue::StructInstance {
602 struct_name: b_s,
603 fields: b_f,
604 },
605 ) => {
606 a_s == b_s
607 && a_f.len() == b_f.len()
608 && a_f
609 .iter()
610 .zip(b_f.iter())
611 .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
612 }
613 (VmValue::Set(a), VmValue::Set(b)) => {
614 a.len() == b.len() && a.iter().all(|x| b.iter().any(|y| values_equal(x, y)))
615 }
616 (VmValue::Generator(_), VmValue::Generator(_)) => false, _ => false,
618 }
619}
620
621pub fn compare_values(a: &VmValue, b: &VmValue) -> i32 {
622 match (a, b) {
623 (VmValue::Int(x), VmValue::Int(y)) => x.cmp(y) as i32,
624 (VmValue::Float(x), VmValue::Float(y)) => {
625 if x < y {
626 -1
627 } else if x > y {
628 1
629 } else {
630 0
631 }
632 }
633 (VmValue::Int(x), VmValue::Float(y)) => {
634 let x = *x as f64;
635 if x < *y {
636 -1
637 } else if x > *y {
638 1
639 } else {
640 0
641 }
642 }
643 (VmValue::Float(x), VmValue::Int(y)) => {
644 let y = *y as f64;
645 if *x < y {
646 -1
647 } else if *x > y {
648 1
649 } else {
650 0
651 }
652 }
653 (VmValue::String(x), VmValue::String(y)) => x.cmp(y) as i32,
654 _ => 0,
655 }
656}