1use std::collections::BTreeMap;
2use std::rc::Rc;
3use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
4use std::sync::Arc;
5use std::{cell::RefCell, path::PathBuf};
6
7use crate::chunk::CompiledFunction;
8use crate::mcp::VmMcpClientHandle;
9
/// An asynchronous builtin: called with the evaluated argument list, it returns a
/// pinned, boxed future resolving to the call result.
///
/// The closure is held in an `Rc`, and the returned future carries no `Send`
/// bound, so neither may be moved across threads.
pub type VmAsyncBuiltinFn = Rc<
    dyn Fn(
        Vec<VmValue>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<VmValue, VmError>>>>,
>;

/// Join handle of a spawned tokio task. On success the task yields a value
/// together with a `String` (presumably output captured while the task ran —
/// TODO confirm against the spawning code).
pub type VmJoinHandle = tokio::task::JoinHandle<Result<(VmValue, String), VmError>>;
19
/// A spawned task paired with a shared boolean flag.
pub struct VmTaskHandle {
    /// Join handle of the underlying tokio task.
    pub handle: VmJoinHandle,
    /// Shared flag; presumably set to request cooperative cancellation and
    /// polled by the running task — confirm against the task body.
    pub cancel_token: Arc<AtomicBool>,
}
26
/// A named channel value.
///
/// All shared state sits behind `Arc`, so clones of a handle refer to the same
/// sender, receiver and closed flag.
#[derive(Debug, Clone)]
pub struct VmChannelHandle {
    /// Name shown when the channel is displayed (`<channel:NAME>`).
    pub name: String,
    /// Sending half of the tokio mpsc channel.
    pub sender: Arc<tokio::sync::mpsc::Sender<VmValue>>,
    /// Receiving half, behind an async mutex so clones can share it.
    pub receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
    /// Shared "closed" flag (only stored here; its writers live elsewhere).
    pub closed: Arc<AtomicBool>,
}
35
/// A shared atomic integer. Clones alias the same `AtomicI64`, so a store made
/// through one handle is visible through every clone.
#[derive(Debug, Clone)]
pub struct VmAtomicHandle {
    pub value: Arc<AtomicI64>,
}
41
/// A generator value. It uses `Rc` (not `Arc`), so it is confined to one thread.
#[derive(Debug, Clone)]
pub struct VmGenerator {
    /// Whether the generator has finished; rendered as `<generator (done)>`.
    pub done: Rc<std::cell::Cell<bool>>,
    /// Channel receiver the generator's yielded values arrive on (presumably
    /// fed by the generator body running as a separate task — confirm).
    pub receiver: Rc<tokio::sync::Mutex<tokio::sync::mpsc::Receiver<VmValue>>>,
}
53
/// A runtime value of the VM.
///
/// `Rc`-backed variants (strings, lists, dicts, sets, closures, builtin refs)
/// share their data when cloned; `values_identical` compares those by allocation.
#[derive(Debug, Clone)]
pub enum VmValue {
    Int(i64),
    Float(f64),
    String(Rc<str>),
    Bool(bool),
    Nil,
    List(Rc<Vec<VmValue>>),
    /// Keys are kept sorted (`BTreeMap`), so display and comparison are order-stable.
    Dict(Rc<BTreeMap<String, VmValue>>),
    Closure(Rc<VmClosure>),
    /// A builtin function referenced by name.
    BuiltinRef(Rc<str>),
    /// A duration in milliseconds.
    Duration(u64),
    /// A value of a user enum, with positional payload fields.
    EnumVariant {
        enum_name: String,
        variant: String,
        fields: Vec<VmValue>,
    },
    /// A value of a user struct, with named fields (sorted by name).
    StructInstance {
        struct_name: String,
        fields: BTreeMap<String, VmValue>,
    },
    /// Identifier of a spawned task (displayed as `<task:ID>`).
    TaskHandle(String),
    Channel(VmChannelHandle),
    Atomic(VmAtomicHandle),
    McpClient(VmMcpClientHandle),
    /// A set stored as a vector; `values_equal` compares sets ignoring order.
    Set(Rc<Vec<VmValue>>),
    Generator(VmGenerator),
}
86
/// A user-defined function value: compiled code plus the context it runs in.
#[derive(Debug, Clone)]
pub struct VmClosure {
    pub func: CompiledFunction,
    /// Environment associated with the closure (presumably the variables
    /// captured when it was created — confirm at the construction site).
    pub env: VmEnv,
    /// Directory of the defining source file, if known (presumably used to
    /// resolve relative paths/imports — confirm).
    pub source_dir: Option<PathBuf>,
    /// Shared function table of the defining module, if any.
    pub module_functions: Option<ModuleFunctionRegistry>,
    /// Shared, mutable module-level environment, if any.
    pub module_state: Option<ModuleState>,
}

/// Name → closure table, shared and mutable through `Rc<RefCell<_>>`.
pub type ModuleFunctionRegistry = Rc<RefCell<BTreeMap<String, Rc<VmClosure>>>>;
/// Module-level variable environment, shared and mutable through `Rc<RefCell<_>>`.
pub type ModuleState = Rc<RefCell<VmEnv>>;
117
/// A lexical environment: a stack of scopes with the innermost scope last.
///
/// `VmEnv::new` starts with one root scope, which `pop_scope` and
/// `truncate_scopes` never remove.
#[derive(Debug, Clone)]
pub struct VmEnv {
    pub(crate) scopes: Vec<Scope>,
}

/// The bindings of a single lexical scope.
#[derive(Debug, Clone)]
pub(crate) struct Scope {
    /// Variable name → (value, is_mutable).
    pub(crate) vars: BTreeMap<String, (VmValue, bool)>,
}

impl Default for VmEnv {
    /// Same as `VmEnv::new`: an environment holding a single empty scope.
    fn default() -> Self {
        Self::new()
    }
}
134
135impl VmEnv {
136 pub fn new() -> Self {
137 Self {
138 scopes: vec![Scope {
139 vars: BTreeMap::new(),
140 }],
141 }
142 }
143
144 pub fn push_scope(&mut self) {
145 self.scopes.push(Scope {
146 vars: BTreeMap::new(),
147 });
148 }
149
150 pub fn pop_scope(&mut self) {
151 if self.scopes.len() > 1 {
152 self.scopes.pop();
153 }
154 }
155
156 pub fn scope_depth(&self) -> usize {
157 self.scopes.len()
158 }
159
160 pub fn truncate_scopes(&mut self, target_depth: usize) {
161 let min_depth = target_depth.max(1);
162 while self.scopes.len() > min_depth {
163 self.scopes.pop();
164 }
165 }
166
167 pub fn get(&self, name: &str) -> Option<VmValue> {
168 for scope in self.scopes.iter().rev() {
169 if let Some((val, _)) = scope.vars.get(name) {
170 return Some(val.clone());
171 }
172 }
173 None
174 }
175
176 pub fn define(&mut self, name: &str, value: VmValue, mutable: bool) -> Result<(), VmError> {
177 if let Some(scope) = self.scopes.last_mut() {
178 if let Some((_, existing_mutable)) = scope.vars.get(name) {
179 if !existing_mutable && !mutable {
180 return Err(VmError::Runtime(format!(
181 "Cannot redeclare immutable variable '{name}' in the same scope (use 'var' for mutable bindings)"
182 )));
183 }
184 }
185 scope.vars.insert(name.to_string(), (value, mutable));
186 }
187 Ok(())
188 }
189
190 pub fn all_variables(&self) -> BTreeMap<String, VmValue> {
191 let mut vars = BTreeMap::new();
192 for scope in &self.scopes {
193 for (name, (value, _)) in &scope.vars {
194 vars.insert(name.clone(), value.clone());
195 }
196 }
197 vars
198 }
199
200 pub fn assign(&mut self, name: &str, value: VmValue) -> Result<(), VmError> {
201 for scope in self.scopes.iter_mut().rev() {
202 if let Some((_, mutable)) = scope.vars.get(name) {
203 if !mutable {
204 return Err(VmError::ImmutableAssignment(name.to_string()));
205 }
206 scope.vars.insert(name.to_string(), (value, true));
207 return Ok(());
208 }
209 }
210 Err(VmError::UndefinedVariable(name.to_string()))
211 }
212}
213
/// Levenshtein edit distance between `a` and `b`, counted in Unicode scalar
/// values (`char`s). Uses two rolling rows, so extra memory is O(len(b)).
fn levenshtein(a: &str, b: &str) -> usize {
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();
    let (m, n) = (a.len(), b.len());
    // `prev` holds row i-1 of the DP table; row 0 is the distance from "" to b[..j].
    let mut prev = (0..=n).collect::<Vec<_>>();
    let mut curr = vec![0; n + 1];
    for i in 1..=m {
        curr[0] = i;
        for j in 1..=n {
            let cost = if a[i - 1] == b[j - 1] { 0 } else { 1 };
            curr[j] = (prev[j] + 1).min(curr[j - 1] + 1).min(prev[j - 1] + cost);
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    prev[n]
}

/// Picks the candidate closest to `name`, for "did you mean …?" style hints.
///
/// Candidates equal to `name`, or starting with `__`, are ignored. The allowed
/// edit distance grows with the length of `name` **in characters**: 1 for up to
/// 2 chars, 2 for 3–5 chars, and 3 beyond that. Among candidates within the
/// limit, the smallest distance wins. Ties go to the candidate whose character
/// length is closest to `name`'s, then to the lexicographically smallest one.
///
/// Returns `None` when no candidate is close enough.
pub fn closest_match<'a>(name: &str, candidates: impl Iterator<Item = &'a str>) -> Option<String> {
    // Lengths are counted in chars to match `levenshtein`. Byte length would give
    // non-ASCII names a looser threshold than ASCII names of the same length.
    let name_len = name.chars().count();
    let max_dist = match name_len {
        0..=2 => 1,
        3..=5 => 2,
        _ => 3,
    };
    let len_gap =
        |candidate: &str| (candidate.chars().count() as isize - name_len as isize).unsigned_abs();
    candidates
        .filter(|c| *c != name && !c.starts_with("__"))
        .map(|c| (c, levenshtein(name, c)))
        .filter(|&(_, dist)| dist <= max_dist)
        .min_by(|&(a, da), &(b, db)| {
            da.cmp(&db)
                .then_with(|| len_gap(a).cmp(&len_gap(b)))
                .then_with(|| a.cmp(b))
        })
        .map(|(c, _)| c.to_string())
}
257
/// Errors raised while running VM code. `Return` is special: its `Display` text
/// is "Return from function" rather than a failure message.
#[derive(Debug, Clone)]
pub enum VmError {
    StackUnderflow,
    /// Too many nested calls.
    StackOverflow,
    UndefinedVariable(String),
    UndefinedBuiltin(String),
    /// Assignment to a binding declared immutable.
    ImmutableAssignment(String),
    TypeError(String),
    Runtime(String),
    DivisionByZero,
    /// A value thrown by user code.
    Thrown(VmValue),
    /// An error explicitly tagged with an `ErrorCategory`.
    CategorizedError {
        message: String,
        category: ErrorCategory,
    },
    /// Carries a function's return value (presumably used to unwind a `return`
    /// through the error path — confirm in the interpreter loop).
    Return(VmValue),
    /// An unrecognised opcode byte.
    InvalidInstruction(u8),
}
277
/// Coarse classification of a failure, so callers can branch on the kind of
/// error instead of parsing its message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCategory {
    /// A deadline or timeout was exceeded.
    Timeout,
    /// Authentication or authorization failed.
    Auth,
    /// Throttling, overload or quota exhaustion.
    RateLimit,
    /// A tool invocation failed.
    ToolError,
    /// A tool invocation was rejected.
    ToolRejected,
    /// The operation was cancelled.
    Cancelled,
    /// The requested resource does not exist.
    NotFound,
    /// A circuit breaker is open.
    CircuitOpen,
    /// Anything not covered by another category.
    Generic,
}

impl ErrorCategory {
    /// Stable snake_case identifier of the category; `parse` maps it back.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Timeout => "timeout",
            Self::Auth => "auth",
            Self::RateLimit => "rate_limit",
            Self::ToolError => "tool_error",
            Self::ToolRejected => "tool_rejected",
            Self::Cancelled => "cancelled",
            Self::NotFound => "not_found",
            Self::CircuitOpen => "circuit_open",
            Self::Generic => "generic",
        }
    }

    /// Parses an identifier produced by `as_str` (case-sensitive). Unknown
    /// input, including the empty string, yields `Generic`.
    pub fn parse(s: &str) -> Self {
        // `Generic` is deliberately absent: it is the fallback for every miss.
        let specific = [
            Self::Timeout,
            Self::Auth,
            Self::RateLimit,
            Self::ToolError,
            Self::ToolRejected,
            Self::Cancelled,
            Self::NotFound,
            Self::CircuitOpen,
        ];
        specific
            .iter()
            .find(|category| category.as_str() == s)
            .cloned()
            .unwrap_or(Self::Generic)
    }
}
330
331pub fn categorized_error(message: impl Into<String>, category: ErrorCategory) -> VmError {
333 VmError::CategorizedError {
334 message: message.into(),
335 category,
336 }
337}
338
339pub fn error_to_category(err: &VmError) -> ErrorCategory {
348 match err {
349 VmError::CategorizedError { category, .. } => category.clone(),
350 VmError::Thrown(VmValue::Dict(d)) => d
351 .get("category")
352 .map(|v| ErrorCategory::parse(&v.display()))
353 .unwrap_or(ErrorCategory::Generic),
354 VmError::Thrown(VmValue::String(s)) => classify_error_message(s),
355 VmError::Runtime(msg) => classify_error_message(msg),
356 _ => ErrorCategory::Generic,
357 }
358}
359
360fn classify_error_message(msg: &str) -> ErrorCategory {
363 if let Some(cat) = classify_by_http_status(msg) {
365 return cat;
366 }
367 if msg.contains("Deadline exceeded") || msg.contains("context deadline exceeded") {
370 return ErrorCategory::Timeout;
371 }
372 if msg.contains("overloaded_error") || msg.contains("api_error") {
373 return ErrorCategory::RateLimit;
375 }
376 if msg.contains("insufficient_quota") || msg.contains("billing_hard_limit_reached") {
377 return ErrorCategory::RateLimit;
379 }
380 if msg.contains("invalid_api_key") || msg.contains("authentication_error") {
381 return ErrorCategory::Auth;
382 }
383 if msg.contains("not_found_error") || msg.contains("model_not_found") {
384 return ErrorCategory::NotFound;
385 }
386 if msg.contains("circuit_open") {
387 return ErrorCategory::CircuitOpen;
388 }
389 ErrorCategory::Generic
390}
391
392fn classify_by_http_status(msg: &str) -> Option<ErrorCategory> {
396 for code in extract_http_status_codes(msg) {
399 return Some(match code {
400 401 | 403 => ErrorCategory::Auth,
401 404 | 410 => ErrorCategory::NotFound,
402 408 | 504 | 522 | 524 => ErrorCategory::Timeout,
403 429 | 503 => ErrorCategory::RateLimit,
404 _ => continue,
405 });
406 }
407 None
408}
409
/// Collects every standalone three-digit number in `msg` that lies in
/// `400..=599`, in order of appearance.
///
/// "Standalone" means the three ASCII digits are not directly preceded or
/// followed by another ASCII digit, so `1404` and `4045` contribute nothing.
/// Only ASCII digits are inspected, so multi-byte UTF-8 text is handled safely.
fn extract_http_status_codes(msg: &str) -> Vec<u16> {
    let bytes = msg.as_bytes();
    let digit_at = |idx: usize| bytes.get(idx).map_or(false, u8::is_ascii_digit);
    bytes
        .windows(3)
        .enumerate()
        .filter(|&(start, window)| {
            window.iter().all(u8::is_ascii_digit)
                && (start == 0 || !digit_at(start - 1))
                && !digit_at(start + 3)
        })
        .map(|(_, window)| {
            window
                .iter()
                .fold(0u16, |acc, &digit| acc * 10 + u16::from(digit - b'0'))
        })
        .filter(|code| (400..=599).contains(code))
        .collect()
}
434
435impl std::fmt::Display for VmError {
436 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
437 match self {
438 VmError::StackUnderflow => write!(f, "Stack underflow"),
439 VmError::StackOverflow => write!(f, "Stack overflow: too many nested calls"),
440 VmError::UndefinedVariable(n) => write!(f, "Undefined variable: {n}"),
441 VmError::UndefinedBuiltin(n) => write!(f, "Undefined builtin: {n}"),
442 VmError::ImmutableAssignment(n) => {
443 write!(f, "Cannot assign to immutable binding: {n}")
444 }
445 VmError::TypeError(msg) => write!(f, "Type error: {msg}"),
446 VmError::Runtime(msg) => write!(f, "Runtime error: {msg}"),
447 VmError::DivisionByZero => write!(f, "Division by zero"),
448 VmError::Thrown(v) => write!(f, "Thrown: {}", v.display()),
449 VmError::CategorizedError { message, category } => {
450 write!(f, "Error [{}]: {}", category.as_str(), message)
451 }
452 VmError::Return(_) => write!(f, "Return from function"),
453 VmError::InvalidInstruction(op) => write!(f, "Invalid instruction: 0x{op:02x}"),
454 }
455 }
456}
457
458impl std::error::Error for VmError {}
459
impl VmValue {
    /// Truthiness of the value.
    ///
    /// Falsy: `false`, `nil`, integer `0`, float `0.0` / `-0.0`, empty strings,
    /// lists, dicts and sets, and a zero duration. Everything else is truthy,
    /// including NaN, enum variants, structs, closures, builtins, handles and
    /// generators.
    pub fn is_truthy(&self) -> bool {
        match self {
            VmValue::Bool(b) => *b,
            VmValue::Nil => false,
            VmValue::Int(n) => *n != 0,
            VmValue::Float(n) => *n != 0.0,
            VmValue::String(s) => !s.is_empty(),
            VmValue::List(l) => !l.is_empty(),
            VmValue::Dict(d) => !d.is_empty(),
            VmValue::Closure(_) => true,
            VmValue::BuiltinRef(_) => true,
            VmValue::Duration(ms) => *ms > 0,
            VmValue::EnumVariant { .. } => true,
            VmValue::StructInstance { .. } => true,
            VmValue::TaskHandle(_) => true,
            VmValue::Channel(_) => true,
            VmValue::Atomic(_) => true,
            VmValue::McpClient(_) => true,
            VmValue::Set(s) => !s.is_empty(),
            VmValue::Generator(_) => true,
        }
    }

    /// Short lowercase name of the value's runtime type (e.g. `"int"`, `"task_handle"`).
    pub fn type_name(&self) -> &'static str {
        match self {
            VmValue::String(_) => "string",
            VmValue::Int(_) => "int",
            VmValue::Float(_) => "float",
            VmValue::Bool(_) => "bool",
            VmValue::Nil => "nil",
            VmValue::List(_) => "list",
            VmValue::Dict(_) => "dict",
            VmValue::Closure(_) => "closure",
            VmValue::BuiltinRef(_) => "builtin",
            VmValue::Duration(_) => "duration",
            VmValue::EnumVariant { .. } => "enum",
            VmValue::StructInstance { .. } => "struct",
            VmValue::TaskHandle(_) => "task_handle",
            VmValue::Channel(_) => "channel",
            VmValue::Atomic(_) => "atomic",
            VmValue::McpClient(_) => "mcp_client",
            VmValue::Set(_) => "set",
            VmValue::Generator(_) => "generator",
        }
    }

    /// Returns the value's textual rendering as a new `String`
    /// (see `write_display` for the format).
    pub fn display(&self) -> String {
        let mut out = String::new();
        self.write_display(&mut out);
        out
    }

    /// Appends the value's textual rendering to `out`, recursing into nested values.
    ///
    /// Format summary:
    /// - strings are written raw (no quotes, no escaping);
    /// - integral floats with magnitude below 1e15 keep one decimal (`3.0`);
    ///   all other floats use Rust's default `f64` formatting;
    /// - durations use the largest of `h` / `m` / `s` that divides the
    ///   millisecond count exactly, otherwise `ms`;
    /// - lists `[a, b]`, dicts `{k: v}`, sets `set(a, b)`, enum variants
    ///   `Enum.Variant` / `Enum.Variant(a, b)`, structs `Name {k: v}`;
    /// - closures `<fn(p1, p2)>`, builtins `<builtin NAME>`, and handles
    ///   `<task:ID>`, `<channel:NAME>`, `<atomic:VALUE>`, `<mcp_client:NAME>`.
    pub fn write_display(&self, out: &mut String) {
        // Writing into a `String` cannot fail, so every `fmt::Result` below is discarded.
        use std::fmt::Write;
        match self {
            VmValue::Int(n) => {
                let _ = write!(out, "{n}");
            }
            VmValue::Float(n) => {
                // Whole numbers get a trailing ".0" so they still read as floats; the
                // 1e15 bound sends very large magnitudes to the default formatting.
                if *n == (*n as i64) as f64 && n.abs() < 1e15 {
                    let _ = write!(out, "{n:.1}");
                } else {
                    let _ = write!(out, "{n}");
                }
            }
            VmValue::String(s) => out.push_str(s),
            VmValue::Bool(b) => out.push_str(if *b { "true" } else { "false" }),
            VmValue::Nil => out.push_str("nil"),
            VmValue::List(items) => {
                out.push('[');
                for (i, item) in items.iter().enumerate() {
                    if i > 0 {
                        out.push_str(", ");
                    }
                    item.write_display(out);
                }
                out.push(']');
            }
            VmValue::Dict(map) => {
                out.push('{');
                for (i, (k, v)) in map.iter().enumerate() {
                    if i > 0 {
                        out.push_str(", ");
                    }
                    out.push_str(k);
                    out.push_str(": ");
                    v.write_display(out);
                }
                out.push('}');
            }
            VmValue::Closure(c) => {
                let _ = write!(out, "<fn({})>", c.func.params.join(", "));
            }
            VmValue::BuiltinRef(name) => {
                let _ = write!(out, "<builtin {name}>");
            }
            VmValue::Duration(ms) => {
                // Largest unit that represents the value exactly: hours, minutes, seconds, else ms.
                if *ms >= 3_600_000 && ms % 3_600_000 == 0 {
                    let _ = write!(out, "{}h", ms / 3_600_000);
                } else if *ms >= 60_000 && ms % 60_000 == 0 {
                    let _ = write!(out, "{}m", ms / 60_000);
                } else if *ms >= 1000 && ms % 1000 == 0 {
                    let _ = write!(out, "{}s", ms / 1000);
                } else {
                    let _ = write!(out, "{}ms", ms);
                }
            }
            VmValue::EnumVariant {
                enum_name,
                variant,
                fields,
            } => {
                if fields.is_empty() {
                    let _ = write!(out, "{enum_name}.{variant}");
                } else {
                    let _ = write!(out, "{enum_name}.{variant}(");
                    for (i, v) in fields.iter().enumerate() {
                        if i > 0 {
                            out.push_str(", ");
                        }
                        v.write_display(out);
                    }
                    out.push(')');
                }
            }
            VmValue::StructInstance {
                struct_name,
                fields,
            } => {
                let _ = write!(out, "{struct_name} {{");
                for (i, (k, v)) in fields.iter().enumerate() {
                    if i > 0 {
                        out.push_str(", ");
                    }
                    out.push_str(k);
                    out.push_str(": ");
                    v.write_display(out);
                }
                out.push('}');
            }
            VmValue::TaskHandle(id) => {
                let _ = write!(out, "<task:{id}>");
            }
            VmValue::Channel(ch) => {
                let _ = write!(out, "<channel:{}>", ch.name);
            }
            VmValue::Atomic(a) => {
                // Shows the atomic's current value at the time of rendering.
                let _ = write!(out, "<atomic:{}>", a.value.load(Ordering::SeqCst));
            }
            VmValue::McpClient(c) => {
                let _ = write!(out, "<mcp_client:{}>", c.name);
            }
            VmValue::Set(items) => {
                out.push_str("set(");
                for (i, item) in items.iter().enumerate() {
                    if i > 0 {
                        out.push_str(", ");
                    }
                    item.write_display(out);
                }
                out.push(')');
            }
            VmValue::Generator(g) => {
                if g.done.get() {
                    out.push_str("<generator (done)>");
                } else {
                    out.push_str("<generator>");
                }
            }
        }
    }

    /// Borrows the map if this value is a `Dict`, otherwise returns `None`.
    pub fn as_dict(&self) -> Option<&BTreeMap<String, VmValue>> {
        if let VmValue::Dict(d) = self {
            Some(d)
        } else {
            None
        }
    }

    /// Returns the integer if this value is an `Int` (floats are not converted),
    /// otherwise `None`.
    pub fn as_int(&self) -> Option<i64> {
        if let VmValue::Int(n) = self {
            Some(*n)
        } else {
            None
        }
    }
}
652
/// A synchronous builtin: receives the argument slice plus a mutable `String`
/// (presumably an output buffer the builtin may append to — confirm at the
/// call sites) and returns the call result.
pub type VmBuiltinFn = Rc<dyn Fn(&[VmValue], &mut String) -> Result<VmValue, VmError>>;
655
656pub fn values_identical(a: &VmValue, b: &VmValue) -> bool {
661 match (a, b) {
662 (VmValue::List(x), VmValue::List(y)) => Rc::ptr_eq(x, y),
663 (VmValue::Dict(x), VmValue::Dict(y)) => Rc::ptr_eq(x, y),
664 (VmValue::Set(x), VmValue::Set(y)) => Rc::ptr_eq(x, y),
665 (VmValue::Closure(x), VmValue::Closure(y)) => Rc::ptr_eq(x, y),
666 (VmValue::String(x), VmValue::String(y)) => Rc::ptr_eq(x, y) || x == y,
667 (VmValue::BuiltinRef(x), VmValue::BuiltinRef(y)) => x == y,
668 _ => values_equal(a, b),
670 }
671}
672
673pub fn value_identity_key(v: &VmValue) -> String {
678 match v {
679 VmValue::List(x) => format!("list@{:p}", Rc::as_ptr(x)),
680 VmValue::Dict(x) => format!("dict@{:p}", Rc::as_ptr(x)),
681 VmValue::Set(x) => format!("set@{:p}", Rc::as_ptr(x)),
682 VmValue::Closure(x) => format!("closure@{:p}", Rc::as_ptr(x)),
683 VmValue::String(x) => format!("string@{:p}", x.as_ptr()),
684 VmValue::BuiltinRef(name) => format!("builtin@{name}"),
685 other => format!("{}@{}", other.type_name(), other.display()),
686 }
687}
688
689pub fn value_structural_hash_key(v: &VmValue) -> String {
695 let mut out = String::new();
696 write_structural_hash_key(v, &mut out);
697 out
698}
699
700fn write_structural_hash_key(v: &VmValue, out: &mut String) {
704 match v {
705 VmValue::Nil => out.push('N'),
706 VmValue::Bool(b) => {
707 out.push(if *b { 'T' } else { 'F' });
708 }
709 VmValue::Int(n) => {
710 out.push('i');
711 out.push_str(&n.to_string());
712 out.push(';');
713 }
714 VmValue::Float(n) => {
715 out.push('f');
716 out.push_str(&n.to_bits().to_string());
717 out.push(';');
718 }
719 VmValue::String(s) => {
720 out.push('s');
722 out.push_str(&s.len().to_string());
723 out.push(':');
724 out.push_str(s);
725 }
726 VmValue::Duration(ms) => {
727 out.push('d');
728 out.push_str(&ms.to_string());
729 out.push(';');
730 }
731 VmValue::List(items) => {
732 out.push('L');
733 for item in items.iter() {
734 write_structural_hash_key(item, out);
735 out.push(',');
736 }
737 out.push(']');
738 }
739 VmValue::Dict(map) => {
740 out.push('D');
741 for (k, v) in map.iter() {
742 out.push_str(&k.len().to_string());
744 out.push(':');
745 out.push_str(k);
746 out.push('=');
747 write_structural_hash_key(v, out);
748 out.push(',');
749 }
750 out.push('}');
751 }
752 VmValue::Set(items) => {
753 let mut keys: Vec<String> = items.iter().map(value_structural_hash_key).collect();
755 keys.sort();
756 out.push('S');
757 for k in &keys {
758 out.push_str(k);
759 out.push(',');
760 }
761 out.push('}');
762 }
763 other => {
764 let tn = other.type_name();
765 out.push('o');
766 out.push_str(&tn.len().to_string());
767 out.push(':');
768 out.push_str(tn);
769 let d = other.display();
770 out.push_str(&d.len().to_string());
771 out.push(':');
772 out.push_str(&d);
773 }
774 }
775}
776
777pub fn values_equal(a: &VmValue, b: &VmValue) -> bool {
778 match (a, b) {
779 (VmValue::Int(x), VmValue::Int(y)) => x == y,
780 (VmValue::Float(x), VmValue::Float(y)) => x == y,
781 (VmValue::String(x), VmValue::String(y)) => x == y,
782 (VmValue::Bool(x), VmValue::Bool(y)) => x == y,
783 (VmValue::Nil, VmValue::Nil) => true,
784 (VmValue::Int(x), VmValue::Float(y)) => (*x as f64) == *y,
785 (VmValue::Float(x), VmValue::Int(y)) => *x == (*y as f64),
786 (VmValue::TaskHandle(a), VmValue::TaskHandle(b)) => a == b,
787 (VmValue::Channel(_), VmValue::Channel(_)) => false, (VmValue::Atomic(a), VmValue::Atomic(b)) => {
789 a.value.load(Ordering::SeqCst) == b.value.load(Ordering::SeqCst)
790 }
791 (VmValue::List(a), VmValue::List(b)) => {
792 a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| values_equal(x, y))
793 }
794 (VmValue::Dict(a), VmValue::Dict(b)) => {
795 a.len() == b.len()
796 && a.iter()
797 .zip(b.iter())
798 .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
799 }
800 (
801 VmValue::EnumVariant {
802 enum_name: a_e,
803 variant: a_v,
804 fields: a_f,
805 },
806 VmValue::EnumVariant {
807 enum_name: b_e,
808 variant: b_v,
809 fields: b_f,
810 },
811 ) => {
812 a_e == b_e
813 && a_v == b_v
814 && a_f.len() == b_f.len()
815 && a_f.iter().zip(b_f.iter()).all(|(x, y)| values_equal(x, y))
816 }
817 (
818 VmValue::StructInstance {
819 struct_name: a_s,
820 fields: a_f,
821 },
822 VmValue::StructInstance {
823 struct_name: b_s,
824 fields: b_f,
825 },
826 ) => {
827 a_s == b_s
828 && a_f.len() == b_f.len()
829 && a_f
830 .iter()
831 .zip(b_f.iter())
832 .all(|((k1, v1), (k2, v2))| k1 == k2 && values_equal(v1, v2))
833 }
834 (VmValue::Set(a), VmValue::Set(b)) => {
835 a.len() == b.len() && a.iter().all(|x| b.iter().any(|y| values_equal(x, y)))
836 }
837 (VmValue::Generator(_), VmValue::Generator(_)) => false, _ => false,
839 }
840}
841
842pub fn compare_values(a: &VmValue, b: &VmValue) -> i32 {
843 match (a, b) {
844 (VmValue::Int(x), VmValue::Int(y)) => x.cmp(y) as i32,
845 (VmValue::Float(x), VmValue::Float(y)) => {
846 if x < y {
847 -1
848 } else if x > y {
849 1
850 } else {
851 0
852 }
853 }
854 (VmValue::Int(x), VmValue::Float(y)) => {
855 let x = *x as f64;
856 if x < *y {
857 -1
858 } else if x > *y {
859 1
860 } else {
861 0
862 }
863 }
864 (VmValue::Float(x), VmValue::Int(y)) => {
865 let y = *y as f64;
866 if *x < y {
867 -1
868 } else if *x > y {
869 1
870 } else {
871 0
872 }
873 }
874 (VmValue::String(x), VmValue::String(y)) => x.cmp(y) as i32,
875 _ => 0,
876 }
877}
878
#[cfg(test)]
mod tests {
    use super::*;

    // Compact constructors so each test reads like a value literal.
    fn string(text: &str) -> VmValue {
        VmValue::String(Rc::from(text))
    }

    fn int(n: i64) -> VmValue {
        VmValue::Int(n)
    }

    fn list(items: Vec<VmValue>) -> VmValue {
        VmValue::List(Rc::new(items))
    }

    fn dict(entries: Vec<(&str, VmValue)>) -> VmValue {
        VmValue::Dict(Rc::new(
            entries
                .into_iter()
                .map(|(key, value)| (key.to_owned(), value))
                .collect(),
        ))
    }

    fn key(value: &VmValue) -> String {
        value_structural_hash_key(value)
    }

    #[test]
    fn hash_key_cross_type_distinct() {
        // The "same" 1 as int, string and bool must produce three different keys.
        let as_int = key(&int(1));
        let as_str = key(&string("1"));
        let as_bool = key(&VmValue::Bool(true));
        assert_ne!(as_int, as_str);
        assert_ne!(as_int, as_bool);
        assert_ne!(as_str, as_bool);
    }

    #[test]
    fn hash_key_string_with_separator_chars() {
        // A single string containing separator characters vs. two separate strings.
        let single = list(vec![string("a,string:b")]);
        let pair = list(vec![string("a"), string("b")]);
        assert_ne!(
            key(&single),
            key(&pair),
            "length-prefixed strings must prevent separator collisions"
        );
    }

    #[test]
    fn hash_key_dict_key_with_equals() {
        let with_equals = dict(vec![("a=b", int(1))]);
        let plain = dict(vec![("a", int(1))]);
        assert_ne!(key(&with_equals), key(&plain));
    }

    #[test]
    fn hash_key_nested_list_vs_flat() {
        let nested = list(vec![list(vec![int(1)])]);
        let flat = list(vec![int(1)]);
        assert_ne!(key(&nested), key(&flat));
    }

    #[test]
    fn hash_key_nil() {
        assert_eq!(key(&VmValue::Nil), key(&VmValue::Nil));
    }

    #[test]
    fn hash_key_float_zero_vs_neg_zero() {
        // Floats are keyed by bit pattern, so the two zeros stay distinct.
        let positive = VmValue::Float(0.0);
        let negative = VmValue::Float(-0.0);
        assert_ne!(key(&positive), key(&negative));
    }

    #[test]
    fn hash_key_equal_values_match() {
        let first = list(vec![string("hello"), int(42), VmValue::Bool(false)]);
        let second = list(vec![string("hello"), int(42), VmValue::Bool(false)]);
        assert_eq!(key(&first), key(&second));
    }

    #[test]
    fn hash_key_dict_with_comma_key() {
        let with_comma = dict(vec![("a,b", int(1))]);
        let plain = dict(vec![("a", int(1))]);
        assert_ne!(key(&with_comma), key(&plain));
    }
}