1pub use inventory;
2use runmat_gc_api::GcPtr;
3use runmat_thread_local::runmat_thread_local;
4use std::cell::RefCell;
5use std::collections::HashMap;
6use std::collections::HashSet;
7use std::convert::TryFrom;
8use std::fmt;
9use std::future::Future;
10use std::pin::Pin;
11
12use indexmap::IndexMap;
13use std::sync::OnceLock;
14
#[cfg(target_arch = "wasm32")]
pub mod wasm_registry {
    //! Process-wide registry of builtin functions, constants and builtin docs,
    //! compiled only for `wasm32` targets.
    //!
    //! Submitted items are leaked so they can be handed out as `&'static`
    //! references; each accessor returns a snapshot copy of the current list.

    use super::{BuiltinDoc, BuiltinFunction, Constant};
    use once_cell::sync::Lazy;
    use std::sync::Mutex;

    static FUNCTIONS: Lazy<Mutex<Vec<&'static BuiltinFunction>>> =
        Lazy::new(|| Mutex::new(Vec::new()));
    static CONSTANTS: Lazy<Mutex<Vec<&'static Constant>>> = Lazy::new(|| Mutex::new(Vec::new()));
    static DOCS: Lazy<Mutex<Vec<&'static BuiltinDoc>>> = Lazy::new(|| Mutex::new(Vec::new()));
    static REGISTERED: Lazy<Mutex<bool>> = Lazy::new(|| Mutex::new(false));

    /// Boxes and leaks `value`, then appends the resulting `'static` reference to `slot`.
    ///
    /// Panics if the mutex is poisoned.
    fn register<T>(slot: &Mutex<Vec<&'static T>>, value: T) {
        let pinned: &'static T = Box::leak(Box::new(value));
        slot.lock().unwrap().push(pinned);
    }

    /// Returns a copy of the references currently stored in `slot`.
    ///
    /// Panics if the mutex is poisoned.
    fn snapshot<T>(slot: &Mutex<Vec<&'static T>>) -> Vec<&'static T> {
        slot.lock().unwrap().to_vec()
    }

    /// Adds a builtin function to the registry.
    pub fn submit_builtin_function(func: BuiltinFunction) {
        register(&*FUNCTIONS, func);
    }

    /// Adds a constant to the registry.
    pub fn submit_constant(constant: Constant) {
        register(&*CONSTANTS, constant);
    }

    /// Adds a builtin documentation entry to the registry.
    pub fn submit_builtin_doc(doc: BuiltinDoc) {
        register(&*DOCS, doc);
    }

    /// Snapshot of every builtin function submitted so far, in submission order.
    pub fn builtin_functions() -> Vec<&'static BuiltinFunction> {
        snapshot(&*FUNCTIONS)
    }

    /// Snapshot of every constant submitted so far, in submission order.
    pub fn constants() -> Vec<&'static Constant> {
        snapshot(&*CONSTANTS)
    }

    /// Snapshot of every builtin doc submitted so far, in submission order.
    pub fn builtin_docs() -> Vec<&'static BuiltinDoc> {
        snapshot(&*DOCS)
    }

    /// Sets the "registered" flag to `true`.
    pub fn mark_registered() {
        let mut flag = REGISTERED.lock().unwrap();
        *flag = true;
    }

    /// Reads the "registered" flag (initially `false`).
    pub fn is_registered() -> bool {
        let flag = REGISTERED.lock().unwrap();
        *flag
    }
}
66
/// Dynamically typed runtime value.
///
/// Each variant wraps one of the payload types declared in this crate (or a
/// plain Rust scalar). Equality is structural (`PartialEq` is derived).
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// Fixed-width integer scalar; see [`IntValue`] for the supported widths.
    Int(IntValue),
    /// `f64` scalar.
    Num(f64),
    /// Pair of `f64`s; presumably (real, imaginary) — TODO confirm the order.
    Complex(f64, f64),
    /// Boolean scalar.
    Bool(bool),
    /// Array of logical values; see [`LogicalArray`].
    LogicalArray(LogicalArray),
    /// Single owned string.
    String(String),
    /// Array of strings; see [`StringArray`].
    StringArray(StringArray),
    /// Character matrix; see [`CharArray`].
    CharArray(CharArray),
    /// Real numeric array; see [`Tensor`].
    Tensor(Tensor),
    /// Complex numeric array; see [`ComplexTensor`].
    ComplexTensor(ComplexTensor),
    /// Cell array; see `CellArray`.
    Cell(CellArray),
    /// Struct with named fields; see [`StructValue`].
    Struct(StructValue),
    /// Handle to a tensor managed through `runmat_accelerate_api`.
    GpuTensor(runmat_accelerate_api::GpuTensorHandle),
    /// Object instance; see `ObjectInstance`.
    Object(ObjectInstance),
    /// Handle-object reference; see `HandleRef`.
    HandleObject(HandleRef),
    /// Event listener; see `Listener`.
    Listener(Listener),
    /// Ordered list of values; presumably the multiple outputs of one call — TODO confirm.
    OutputList(Vec<Value>),
    /// Function handle carrying a string; presumably the function name — TODO confirm.
    FunctionHandle(String),
    /// Function handle carrying a string; NOTE(review): how it differs from
    /// `FunctionHandle` is not visible here.
    ExternalFunctionHandle(String),
    /// Function handle carrying a string; presumably a method name — TODO confirm.
    MethodFunctionHandle(String),
    /// Function handle that carries both a name and a numeric `function` value;
    /// NOTE(review): the meaning of `function` (index? id?) is not visible here.
    BoundFunctionHandle {
        name: String,
        function: usize,
    },
    /// Closure with captured values; see [`Closure`].
    Closure(Closure),
    /// Reference to a class, stored as a string; presumably the class name — TODO confirm.
    ClassRef(String),
    /// Exception value; see `MException`.
    MException(MException),
}
/// Integer scalar tagged with its fixed width and signedness.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IntValue {
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
}

impl IntValue {
    /// Converts to `i64`.
    ///
    /// Every variant except `U64` fits losslessly. `U64` values above
    /// `i64::MAX` saturate to `i64::MAX` instead of wrapping.
    pub fn to_i64(&self) -> i64 {
        match *self {
            IntValue::I8(v) => i64::from(v),
            IntValue::I16(v) => i64::from(v),
            IntValue::I32(v) => i64::from(v),
            IntValue::I64(v) => v,
            IntValue::U8(v) => i64::from(v),
            IntValue::U16(v) => i64::from(v),
            IntValue::U32(v) => i64::from(v),
            // Clamp first so the cast below can never wrap into negatives.
            IntValue::U64(v) => v.min(i64::MAX as u64) as i64,
        }
    }

    /// Converts to the nearest `f64`.
    ///
    /// `U64` is converted directly rather than through [`IntValue::to_i64`]:
    /// going through the saturating `i64` path would report every value above
    /// `i64::MAX` as roughly 9.22e18, losing up to half of the magnitude.
    pub fn to_f64(&self) -> f64 {
        match *self {
            IntValue::U64(v) => v as f64,
            _ => self.to_i64() as f64,
        }
    }

    /// Returns `true` when the stored integer is zero.
    pub fn is_zero(&self) -> bool {
        // Saturation in `to_i64` never maps a non-zero value to zero.
        self.to_i64() == 0
    }

    /// Class name string for this integer width (`"int8"` … `"uint64"`).
    pub fn class_name(&self) -> &'static str {
        match self {
            IntValue::I8(_) => "int8",
            IntValue::I16(_) => "int16",
            IntValue::I32(_) => "int32",
            IntValue::I64(_) => "int64",
            IntValue::U8(_) => "uint8",
            IntValue::U16(_) => "uint16",
            IntValue::U32(_) => "uint32",
            IntValue::U64(_) => "uint64",
        }
    }
}
163
164#[derive(Debug, Clone, PartialEq)]
165pub struct StructValue {
166 pub fields: IndexMap<String, Value>,
167}
168
169impl StructValue {
170 pub fn new() -> Self {
171 Self {
172 fields: IndexMap::new(),
173 }
174 }
175
176 pub fn insert(&mut self, name: impl Into<String>, value: Value) -> Option<Value> {
178 self.fields.insert(name.into(), value)
179 }
180
181 pub fn remove(&mut self, name: &str) -> Option<Value> {
183 self.fields.shift_remove(name)
184 }
185
186 pub fn field_names(&self) -> impl Iterator<Item = &String> {
188 self.fields.keys()
189 }
190}
191
192impl Default for StructValue {
193 fn default() -> Self {
194 Self::new()
195 }
196}
197
/// Element type tag attached to a [`Tensor`].
///
/// `Tensor` always stores its elements as `f64`; this tag records which
/// numeric class the data represents.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NumericDType {
    /// Reported as class `"double"`, 8 bytes per element.
    F64,
    /// Reported as class `"single"`, 4 bytes per element.
    F32,
    /// Reported as class `"uint8"`, 1 byte per element.
    U8,
    /// Reported as class `"uint16"`, 2 bytes per element.
    U16,
}
205
206impl NumericDType {
207 pub fn class_name(self) -> &'static str {
208 match self {
209 NumericDType::F64 => "double",
210 NumericDType::F32 => "single",
211 NumericDType::U8 => "uint8",
212 NumericDType::U16 => "uint16",
213 }
214 }
215
216 pub fn byte_size(self) -> usize {
217 match self {
218 NumericDType::F64 => 8,
219 NumericDType::F32 => 4,
220 NumericDType::U8 => 1,
221 NumericDType::U16 => 2,
222 }
223 }
224}
225
/// Real-valued N-dimensional array.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Elements stored as `f64`; the 2-D accessors index this column-major
    /// (`row + col * rows`).
    pub data: Vec<f64>,
    /// Extent of every dimension.
    pub shape: Vec<usize>,
    /// Row count cached by the constructors: `shape[0]` for 2-D+ shapes,
    /// `1` for 1-D shapes, `0` for an empty shape.
    pub rows: usize,
    /// Column count cached by the constructors: `shape[1]` for 2-D+ shapes,
    /// `shape[0]` for 1-D shapes, `0` for an empty shape.
    pub cols: usize,
    /// Numeric class the data represents (storage is always `f64`).
    pub dtype: NumericDType,
}
235
/// Complex-valued N-dimensional array.
#[derive(Debug, Clone, PartialEq)]
pub struct ComplexTensor {
    /// Elements as pairs of `f64`; presumably (real, imaginary) in column-major
    /// order like `Tensor` — TODO confirm.
    pub data: Vec<(f64, f64)>,
    /// Extent of every dimension.
    pub shape: Vec<usize>,
    /// Row count cached by the constructors: `shape[0]` for 2-D+ shapes,
    /// `1` for 1-D shapes, `0` for an empty shape.
    pub rows: usize,
    /// Column count cached by the constructors: `shape[1]` for 2-D+ shapes,
    /// `shape[0]` for 1-D shapes, `0` for an empty shape.
    pub cols: usize,
}
243
/// N-dimensional array of strings.
#[derive(Debug, Clone, PartialEq)]
pub struct StringArray {
    /// Elements; the `Display` impl indexes this column-major (`r + c * rows`)
    /// and renders the literal text `<missing>` without quotes.
    pub data: Vec<String>,
    /// Extent of every dimension.
    pub shape: Vec<usize>,
    /// Row count cached by the constructors: `shape[0]` for 2-D+ shapes,
    /// `1` for 1-D shapes, `0` for an empty shape.
    pub rows: usize,
    /// Column count cached by the constructors: `shape[1]` for 2-D+ shapes,
    /// `shape[0]` for 1-D shapes, `0` for an empty shape.
    pub cols: usize,
}
251
/// N-dimensional array of logical values, one byte (`0` or `1`) per element.
#[derive(Debug, Clone, PartialEq)]
pub struct LogicalArray {
    /// Element bytes; [`LogicalArray::new`] normalises every non-zero byte to `1`.
    pub data: Vec<u8>,
    /// Extent of every dimension.
    pub shape: Vec<usize>,
}

impl LogicalArray {
    /// Builds an array from raw bytes, mapping every non-zero byte to `1`.
    ///
    /// # Errors
    /// Returns a message when `data.len()` differs from the product of `shape`.
    pub fn new(data: Vec<u8>, shape: Vec<usize>) -> Result<Self, String> {
        let expected = shape.iter().product::<usize>();
        if data.len() == expected {
            let normalized = data.into_iter().map(|byte| u8::from(byte != 0)).collect();
            Ok(Self {
                data: normalized,
                shape,
            })
        } else {
            Err(format!(
                "LogicalArray data length {} doesn't match shape {:?} ({} elements)",
                data.len(),
                shape,
                expected
            ))
        }
    }

    /// Builds an all-false array with the given shape.
    pub fn zeros(shape: Vec<usize>) -> Self {
        let count = shape.iter().product::<usize>();
        Self {
            data: vec![0; count],
            shape,
        }
    }

    /// Number of stored elements.
    pub fn len(&self) -> usize {
        let bytes = &self.data;
        bytes.len()
    }

    /// `true` when no elements are stored.
    pub fn is_empty(&self) -> bool {
        let bytes = &self.data;
        bytes.is_empty()
    }
}
290
/// Two-dimensional array of characters.
#[derive(Debug, Clone, PartialEq)]
pub struct CharArray {
    /// Characters; the `Display` impl in this file reads them row-major
    /// (`r * cols + c`).
    pub data: Vec<char>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}

impl CharArray {
    /// Builds a single-row array holding the characters of `s`.
    pub fn new_row(s: &str) -> Self {
        // Collect once and take the column count from the collected data
        // instead of walking the string a second time.
        let data: Vec<char> = s.chars().collect();
        let cols = data.len();
        CharArray {
            data,
            rows: 1,
            cols,
        }
    }

    /// Builds a `rows` x `cols` array from `data`.
    ///
    /// # Errors
    /// Returns a message when `rows * cols` differs from `data.len()`. A
    /// product that overflows `usize` is also rejected, instead of panicking
    /// (debug builds) or wrapping around to a length that happens to match
    /// (release builds).
    pub fn new(data: Vec<char>, rows: usize, cols: usize) -> Result<Self, String> {
        if rows.checked_mul(cols) != Some(data.len()) {
            return Err(format!(
                "Char data length {} doesn't match dimensions {}x{}",
                data.len(),
                rows,
                cols
            ));
        }
        Ok(CharArray { data, rows, cols })
    }
}
318
319impl StringArray {
320 pub fn new(data: Vec<String>, shape: Vec<usize>) -> Result<Self, String> {
321 let expected: usize = shape.iter().product();
322 if data.len() != expected {
323 return Err(format!(
324 "StringArray data length {} doesn't match shape {:?} ({} elements)",
325 data.len(),
326 shape,
327 expected
328 ));
329 }
330 let (rows, cols) = if shape.len() >= 2 {
331 (shape[0], shape[1])
332 } else if shape.len() == 1 {
333 (1, shape[0])
334 } else {
335 (0, 0)
336 };
337 Ok(StringArray {
338 data,
339 shape,
340 rows,
341 cols,
342 })
343 }
344 pub fn new_2d(data: Vec<String>, rows: usize, cols: usize) -> Result<Self, String> {
345 Self::new(data, vec![rows, cols])
346 }
347 pub fn rows(&self) -> usize {
348 self.shape.first().copied().unwrap_or(1)
349 }
350 pub fn cols(&self) -> usize {
351 self.shape.get(1).copied().unwrap_or(1)
352 }
353}
354
355impl Tensor {
358 pub fn new(data: Vec<f64>, shape: Vec<usize>) -> Result<Self, String> {
359 let expected: usize = shape.iter().product();
360 if data.len() != expected {
361 return Err(format!(
362 "Tensor data length {} doesn't match shape {:?} ({} elements)",
363 data.len(),
364 shape,
365 expected
366 ));
367 }
368 let (rows, cols) = if shape.len() >= 2 {
369 (shape[0], shape[1])
370 } else if shape.len() == 1 {
371 (1, shape[0])
372 } else {
373 (0, 0)
374 };
375 Ok(Tensor {
376 data,
377 shape,
378 rows,
379 cols,
380 dtype: NumericDType::F64,
381 })
382 }
383
384 pub fn new_2d(data: Vec<f64>, rows: usize, cols: usize) -> Result<Self, String> {
385 Self::new(data, vec![rows, cols])
386 }
387
388 pub fn from_f32(data: Vec<f32>, shape: Vec<usize>) -> Result<Self, String> {
389 let converted: Vec<f64> = data.into_iter().map(|v| v as f64).collect();
390 Self::new_with_dtype(converted, shape, NumericDType::F32)
391 }
392
393 pub fn from_f32_slice(data: &[f32], shape: &[usize]) -> Result<Self, String> {
394 let converted: Vec<f64> = data.iter().map(|&v| v as f64).collect();
395 Self::new_with_dtype(converted, shape.to_vec(), NumericDType::F32)
396 }
397
398 pub fn new_with_dtype(
399 data: Vec<f64>,
400 shape: Vec<usize>,
401 dtype: NumericDType,
402 ) -> Result<Self, String> {
403 let mut t = Self::new(data, shape)?;
404 t.dtype = dtype;
405 Ok(t)
406 }
407
408 pub fn zeros(shape: Vec<usize>) -> Self {
409 let size: usize = shape.iter().product();
410 let (rows, cols) = if shape.len() >= 2 {
411 (shape[0], shape[1])
412 } else if shape.len() == 1 {
413 (1, shape[0])
414 } else {
415 (0, 0)
416 };
417 Tensor {
418 data: vec![0.0; size],
419 shape,
420 rows,
421 cols,
422 dtype: NumericDType::F64,
423 }
424 }
425
426 pub fn ones(shape: Vec<usize>) -> Self {
427 let size: usize = shape.iter().product();
428 let (rows, cols) = if shape.len() >= 2 {
429 (shape[0], shape[1])
430 } else if shape.len() == 1 {
431 (1, shape[0])
432 } else {
433 (0, 0)
434 };
435 Tensor {
436 data: vec![1.0; size],
437 shape,
438 rows,
439 cols,
440 dtype: NumericDType::F64,
441 }
442 }
443
444 pub fn zeros2(rows: usize, cols: usize) -> Self {
446 Self::zeros(vec![rows, cols])
447 }
448 pub fn ones2(rows: usize, cols: usize) -> Self {
449 Self::ones(vec![rows, cols])
450 }
451
452 pub fn rows(&self) -> usize {
453 self.shape.first().copied().unwrap_or(1)
454 }
455 pub fn cols(&self) -> usize {
456 self.shape.get(1).copied().unwrap_or(1)
457 }
458
459 pub fn get2(&self, row: usize, col: usize) -> Result<f64, String> {
460 let rows = self.rows();
461 let cols = self.cols();
462 if row >= rows || col >= cols {
463 return Err(format!(
464 "Index ({row}, {col}) out of bounds for {rows}x{cols} tensor"
465 ));
466 }
467 Ok(self.data[row + col * rows])
469 }
470
471 pub fn set2(&mut self, row: usize, col: usize, value: f64) -> Result<(), String> {
472 let rows = self.rows();
473 let cols = self.cols();
474 if row >= rows || col >= cols {
475 return Err(format!(
476 "Index ({row}, {col}) out of bounds for {rows}x{cols} tensor"
477 ));
478 }
479 self.data[row + col * rows] = value;
481 Ok(())
482 }
483
484 pub fn scalar_to_tensor2(scalar: f64, rows: usize, cols: usize) -> Tensor {
485 Tensor {
486 data: vec![scalar; rows * cols],
487 shape: vec![rows, cols],
488 rows,
489 cols,
490 dtype: NumericDType::F64,
491 }
492 }
493 }
495
496impl ComplexTensor {
497 pub fn new(data: Vec<(f64, f64)>, shape: Vec<usize>) -> Result<Self, String> {
498 let expected: usize = shape.iter().product();
499 if data.len() != expected {
500 return Err(format!(
501 "ComplexTensor data length {} doesn't match shape {:?} ({} elements)",
502 data.len(),
503 shape,
504 expected
505 ));
506 }
507 let (rows, cols) = if shape.len() >= 2 {
508 (shape[0], shape[1])
509 } else if shape.len() == 1 {
510 (1, shape[0])
511 } else {
512 (0, 0)
513 };
514 Ok(ComplexTensor {
515 data,
516 shape,
517 rows,
518 cols,
519 })
520 }
521 pub fn new_2d(data: Vec<(f64, f64)>, rows: usize, cols: usize) -> Result<Self, String> {
522 Self::new(data, vec![rows, cols])
523 }
524 pub fn zeros(shape: Vec<usize>) -> Self {
525 let size: usize = shape.iter().product();
526 let (rows, cols) = if shape.len() >= 2 {
527 (shape[0], shape[1])
528 } else if shape.len() == 1 {
529 (1, shape[0])
530 } else {
531 (0, 0)
532 };
533 ComplexTensor {
534 data: vec![(0.0, 0.0); size],
535 shape,
536 rows,
537 cols,
538 }
539 }
540}
541
542const MAX_ND_DISPLAY_ELEMENTS: usize = 4096;
543
544fn should_expand_nd_display(shape: &[usize]) -> bool {
545 shape.len() > 2
546 && matches!(
547 total_len(shape),
548 Some(total) if total > 0 && total <= MAX_ND_DISPLAY_ELEMENTS
549 )
550}
551
/// Column-major strides for `shape`: entry `i` is the product of the extents
/// of all dimensions before `i` (so the first stride is `1`).
///
/// Products saturate at `usize::MAX` instead of overflowing.
fn column_major_strides(shape: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .scan(1usize, |running, &extent| {
            let current = *running;
            *running = current.saturating_mul(extent);
            Some(current)
        })
        .collect()
}
561
/// Splits a linear page index into one zero-based coordinate per entry of
/// `page_shape`, with the first dimension varying fastest.
///
/// A zero-extent dimension contributes coordinate `0` and leaves the
/// remaining index untouched.
fn decode_page_coords(mut page_index: usize, page_shape: &[usize]) -> Vec<usize> {
    page_shape
        .iter()
        .map(|&extent| {
            if extent == 0 {
                return 0;
            }
            let coord = page_index % extent;
            page_index /= extent;
            coord
        })
        .collect()
}
574
/// Writes an array with more than two dimensions as a sequence of 2-D pages.
///
/// Each page is introduced by a `(:, :, i, j, ...) =` header with one-based
/// page coordinates, followed by one line per row. `write_element` is called
/// with the formatter and the column-major linear index of each element.
///
/// Does nothing for shapes with two or fewer dimensions. Writes `[]` when the
/// row count, column count or page count is zero, and falls back to
/// `Tensor(shape=...)` when `total_len` cannot produce a page count.
fn write_nd_pages(
    f: &mut fmt::Formatter<'_>,
    shape: &[usize],
    mut write_element: impl FnMut(&mut fmt::Formatter<'_>, usize) -> fmt::Result,
) -> fmt::Result {
    if shape.len() <= 2 {
        return Ok(());
    }
    let rows = shape[0];
    let cols = shape[1];
    if rows == 0 || cols == 0 {
        return write!(f, "[]");
    }
    // Number of 2-D pages = total length of the trailing dimensions.
    let Some(page_count) = total_len(&shape[2..]) else {
        return write!(f, "Tensor(shape={shape:?})");
    };
    if page_count == 0 {
        return write!(f, "[]");
    }
    let strides = column_major_strides(shape);
    for page_index in 0..page_count {
        // Blank line between consecutive pages.
        if page_index > 0 {
            write!(f, "\n\n")?;
        }
        let coords = decode_page_coords(page_index, &shape[2..]);
        // Page header; coordinates are printed one-based.
        write!(f, "(:, :")?;
        for &coord in &coords {
            write!(f, ", {}", coord + 1)?;
        }
        write!(f, ") =")?;

        // Linear offset of this page's first element; `coords[k]` belongs to
        // dimension `k + 2`, hence the stride index shift.
        let mut page_base = 0usize;
        for (offset, &coord) in coords.iter().enumerate() {
            page_base += coord * strides[offset + 2];
        }
        for r in 0..rows {
            writeln!(f)?;
            write!(f, " ")?;
            for c in 0..cols {
                if c > 0 {
                    write!(f, " ")?;
                }
                // Column-major position inside the page.
                let linear = page_base + r + c * rows;
                write_element(f, linear)?;
            }
        }
    }
    Ok(())
}
624
/// Text rendering of a [`Tensor`], chosen by the number of dimensions:
///
/// * 0 or 1 dimensions: all elements on one line inside `[` … `]`.
/// * 2 dimensions: one line per row, each preceded by a newline.
/// * more: page-by-page output via `write_nd_pages` when
///   `should_expand_nd_display` allows it, otherwise a `Tensor(shape=...)`
///   summary.
///
/// Every element is rendered through `format_number`.
impl fmt::Display for Tensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.shape.len() {
            0 | 1 => {
                write!(f, "[")?;
                for (i, v) in self.data.iter().enumerate() {
                    if i > 0 {
                        write!(f, " ")?;
                    }
                    write!(f, "{}", format_number(*v))?;
                }
                write!(f, "]")
            }
            2 => {
                let rows = self.rows();
                let cols = self.cols();
                for r in 0..rows {
                    // Each row starts on its own line.
                    writeln!(f)?;
                    write!(f, " ")?;
                    for c in 0..cols {
                        if c > 0 {
                            write!(f, " ")?;
                        }
                        // Column-major lookup.
                        let v = self.data[r + c * rows];
                        write!(f, "{}", format_number(v))?;
                    }
                }
                Ok(())
            }
            _ => {
                if should_expand_nd_display(&self.shape) {
                    write_nd_pages(f, &self.shape, |f, idx| {
                        write!(f, "{}", format_number(self.data[idx]))
                    })
                } else {
                    write!(f, "Tensor(shape={:?})", self.shape)
                }
            }
        }
    }
}
668
669impl fmt::Display for StringArray {
670 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
671 let (rows, cols) = match self.shape.len() {
672 0 => (0, 0),
673 1 => (1, self.shape[0]),
674 _ => (self.shape[0], self.shape[1]),
675 };
676 let count = self.data.len();
677 if count == 1 && rows == 1 && cols == 1 {
678 let v = &self.data[0];
679 if v == "<missing>" {
680 return write!(f, "<missing>");
681 }
682 let escaped = v.replace('"', "\\\"");
683 return write!(f, "\"{escaped}\"");
684 }
685 if self.shape.len() > 2 {
686 let dims: Vec<String> = self.shape.iter().map(|d| d.to_string()).collect();
687 return write!(f, "{} string array", dims.join("x"));
688 }
689 write!(f, "{rows}x{cols} string array")?;
690 if rows == 0 || cols == 0 {
691 return Ok(());
692 }
693 for r in 0..rows {
694 writeln!(f)?;
695 write!(f, " ")?;
696 for c in 0..cols {
697 if c > 0 {
698 write!(f, " ")?;
699 }
700 let v = &self.data[r + c * rows];
701 if v == "<missing>" {
702 write!(f, "<missing>")?;
703 } else {
704 let escaped = v.replace('"', "\\\"");
705 write!(f, "\"{escaped}\"")?;
706 }
707 }
708 }
709 Ok(())
710 }
711}
712
/// Text rendering of a [`LogicalArray`]; every element prints as `1` or `0`.
///
/// * Exactly one element: just that element, whatever the shape.
/// * Empty shape: `[]`.
/// * 1 dimension: all elements on one line inside `[` … `]`.
/// * 2 dimensions: one line per row, each preceded by a newline.
/// * more: page-by-page output via `write_nd_pages` when
///   `should_expand_nd_display` allows it, otherwise an
///   `AxBxC logical array` summary.
impl fmt::Display for LogicalArray {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-element arrays print as a bare scalar.
        if self.data.len() == 1 {
            return write!(f, "{}", if self.data[0] != 0 { 1 } else { 0 });
        }
        match self.shape.len() {
            0 => write!(f, "[]"),
            1 => {
                write!(f, "[")?;
                for (i, v) in self.data.iter().enumerate() {
                    if i > 0 {
                        write!(f, " ")?;
                    }
                    write!(f, "{}", if *v != 0 { 1 } else { 0 })?;
                }
                write!(f, "]")
            }
            2 => {
                let rows = self.shape[0];
                let cols = self.shape[1];
                for r in 0..rows {
                    // Each row starts on its own line.
                    writeln!(f)?;
                    write!(f, " ")?;
                    for c in 0..cols {
                        if c > 0 {
                            write!(f, " ")?;
                        }
                        // Column-major lookup.
                        let idx = r + c * rows;
                        write!(f, "{}", if self.data[idx] != 0 { 1 } else { 0 })?;
                    }
                }
                Ok(())
            }
            _ => {
                if should_expand_nd_display(&self.shape) {
                    write_nd_pages(f, &self.shape, |f, idx| {
                        write!(f, "{}", if self.data[idx] != 0 { 1 } else { 0 })
                    })
                } else {
                    let dims: Vec<String> = self.shape.iter().map(|d| d.to_string()).collect();
                    write!(f, "{} logical array", dims.join("x"))
                }
            }
        }
    }
}
760
761impl fmt::Display for CharArray {
762 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
763 for r in 0..self.rows {
764 writeln!(f)?;
765 write!(f, " ")?; for c in 0..self.cols {
767 let ch = self.data[r * self.cols + c];
768 write!(f, "{ch}")?;
769 }
770 }
771 Ok(())
772 }
773}
774
775impl From<i32> for Value {
777 fn from(i: i32) -> Self {
778 Value::Int(IntValue::I32(i))
779 }
780}
781impl From<i64> for Value {
782 fn from(i: i64) -> Self {
783 Value::Int(IntValue::I64(i))
784 }
785}
786impl From<u32> for Value {
787 fn from(i: u32) -> Self {
788 Value::Int(IntValue::U32(i))
789 }
790}
791impl From<u64> for Value {
792 fn from(i: u64) -> Self {
793 Value::Int(IntValue::U64(i))
794 }
795}
796impl From<i16> for Value {
797 fn from(i: i16) -> Self {
798 Value::Int(IntValue::I16(i))
799 }
800}
801impl From<i8> for Value {
802 fn from(i: i8) -> Self {
803 Value::Int(IntValue::I8(i))
804 }
805}
806impl From<u16> for Value {
807 fn from(i: u16) -> Self {
808 Value::Int(IntValue::U16(i))
809 }
810}
811impl From<u8> for Value {
812 fn from(i: u8) -> Self {
813 Value::Int(IntValue::U8(i))
814 }
815}
816
817impl From<f64> for Value {
818 fn from(f: f64) -> Self {
819 Value::Num(f)
820 }
821}
822
823impl From<bool> for Value {
824 fn from(b: bool) -> Self {
825 Value::Bool(b)
826 }
827}
828
829impl From<String> for Value {
830 fn from(s: String) -> Self {
831 Value::String(s)
832 }
833}
834
835impl From<&str> for Value {
836 fn from(s: &str) -> Self {
837 Value::String(s.to_string())
838 }
839}
840
841impl From<Tensor> for Value {
842 fn from(m: Tensor) -> Self {
843 Value::Tensor(m)
844 }
845}
846
847impl TryFrom<&Value> for i32 {
851 type Error = String;
852 fn try_from(v: &Value) -> Result<Self, Self::Error> {
853 match v {
854 Value::Int(i) => Ok(i.to_i64() as i32),
855 Value::Num(n) => Ok(*n as i32),
856 _ => Err(format!("cannot convert {v:?} to i32")),
857 }
858 }
859}
860
861impl TryFrom<&Value> for f64 {
862 type Error = String;
863 fn try_from(v: &Value) -> Result<Self, Self::Error> {
864 match v {
865 Value::Num(n) => Ok(*n),
866 Value::Int(i) => Ok(i.to_f64()),
867 _ => Err(format!("cannot convert {v:?} to f64")),
868 }
869 }
870}
871
872impl TryFrom<&Value> for bool {
873 type Error = String;
874 fn try_from(v: &Value) -> Result<Self, Self::Error> {
875 match v {
876 Value::Bool(b) => Ok(*b),
877 Value::Int(i) => Ok(!i.is_zero()),
878 Value::Num(n) => Ok(*n != 0.0),
879 _ => Err(format!("cannot convert {v:?} to bool")),
880 }
881 }
882}
883
884impl TryFrom<&Value> for String {
885 type Error = String;
886 fn try_from(v: &Value) -> Result<Self, Self::Error> {
887 match v {
888 Value::String(s) => Ok(s.clone()),
889 Value::StringArray(sa) => {
890 if sa.data.len() == 1 {
891 Ok(sa.data[0].clone())
892 } else {
893 Err("cannot convert string array to scalar string".to_string())
894 }
895 }
896 Value::CharArray(ca) => {
897 if ca.rows == 1 {
899 Ok(ca.data.iter().collect())
900 } else {
901 Err("cannot convert multi-row char array to scalar string".to_string())
902 }
903 }
904 Value::Int(i) => Ok(i.to_i64().to_string()),
905 Value::Num(n) => Ok(n.to_string()),
906 Value::Bool(b) => Ok(b.to_string()),
907 _ => Err(format!("cannot convert {v:?} to String")),
908 }
909 }
910}
911
912impl TryFrom<&Value> for Tensor {
913 type Error = String;
914 fn try_from(v: &Value) -> Result<Self, Self::Error> {
915 match v {
916 Value::Tensor(m) => Ok(m.clone()),
917 _ => Err(format!("cannot convert {v:?} to Tensor")),
918 }
919 }
920}
921
922impl TryFrom<&Value> for Value {
923 type Error = String;
924 fn try_from(v: &Value) -> Result<Self, Self::Error> {
925 Ok(v.clone())
926 }
927}
928
929impl TryFrom<&Value> for Vec<Value> {
930 type Error = String;
931 fn try_from(v: &Value) -> Result<Self, Self::Error> {
932 match v {
933 Value::Cell(c) => Ok(c.data.iter().map(|p| (**p).clone()).collect()),
934 _ => Err(format!("cannot convert {v:?} to Vec<Value>")),
935 }
936 }
937}
938
939use serde::{Deserialize, Serialize};
940
/// Static type description of a [`Value`], as produced by `Type::from_value`
/// and combined by `Type::unify`.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum Type {
    /// Integer scalar.
    Int,
    /// Floating-point scalar (also what `from_value` reports for complex scalars).
    Num,
    /// Boolean scalar.
    Bool,
    /// Logical array.
    Logical {
        /// Per-dimension extents when known; `None` for the whole shape or
        /// for an individual dimension means "unknown".
        shape: Option<Vec<Option<usize>>>,
    },
    /// String.
    String,
    /// Numeric array.
    Tensor {
        /// Per-dimension extents when known; `None` for the whole shape or
        /// for an individual dimension means "unknown".
        shape: Option<Vec<Option<usize>>>,
    },
    /// Cell array.
    Cell {
        /// Type of the elements, when known.
        element_type: Option<Box<Type>>,
        /// Number of elements, when known.
        length: Option<usize>,
    },
    /// Callable value.
    Function {
        /// Parameter types.
        params: Vec<Type>,
        /// Return type.
        returns: Box<Type>,
    },
    /// NOTE(review): not produced anywhere in the visible code; presumably "no value".
    Void,
    /// Type that could not be determined; compatible with every other type.
    Unknown,
    /// One of several alternative types (produced by `unify` for unrelated types).
    Union(Vec<Type>),
    /// Struct value.
    Struct {
        /// Names of the fields known to exist, when available.
        known_fields: Option<Vec<String>>,
    },
    /// Types of a list of outputs, one entry per output.
    OutputList(Vec<Type>),
}
991
992impl Type {
993 pub fn tensor() -> Self {
995 Type::Tensor { shape: None }
996 }
997
998 pub fn logical() -> Self {
1000 Type::Logical { shape: None }
1001 }
1002
1003 pub fn logical_with_shape(shape: Vec<usize>) -> Self {
1005 Type::Logical {
1006 shape: Some(shape.into_iter().map(Some).collect()),
1007 }
1008 }
1009
1010 pub fn tensor_with_shape(shape: Vec<usize>) -> Self {
1012 Type::Tensor {
1013 shape: Some(shape.into_iter().map(Some).collect()),
1014 }
1015 }
1016
1017 pub fn cell() -> Self {
1019 Type::Cell {
1020 element_type: None,
1021 length: None,
1022 }
1023 }
1024
1025 pub fn cell_of(element_type: Type) -> Self {
1027 Type::Cell {
1028 element_type: Some(Box::new(element_type)),
1029 length: None,
1030 }
1031 }
1032
1033 pub fn is_compatible_with(&self, other: &Type) -> bool {
1035 match (self, other) {
1036 (Type::Unknown, _) | (_, Type::Unknown) => true,
1037 (Type::Int, Type::Num) | (Type::Num, Type::Int) => true, (Type::Tensor { .. }, Type::Tensor { .. }) => true, (Type::OutputList(a), Type::OutputList(b)) => a.len() == b.len(),
1040 (a, b) => a == b,
1041 }
1042 }
1043
1044 pub fn unify(&self, other: &Type) -> Type {
1046 match (self, other) {
1047 (Type::Unknown, t) | (t, Type::Unknown) => t.clone(),
1048 (Type::Int, Type::Num) | (Type::Num, Type::Int) => Type::Num,
1049 (Type::Tensor { shape: a }, Type::Tensor { shape: b }) => {
1050 let a_norm = match a {
1051 Some(dims) if dims.is_empty() => None,
1052 _ => a.clone(),
1053 };
1054 let b_norm = match b {
1055 Some(dims) if dims.is_empty() => None,
1056 _ => b.clone(),
1057 };
1058 let a_unknown = a_norm
1059 .as_ref()
1060 .map(|dims| dims.iter().all(|d| d.is_none()))
1061 .unwrap_or(true);
1062 let b_unknown = b_norm
1063 .as_ref()
1064 .map(|dims| dims.iter().all(|d| d.is_none()))
1065 .unwrap_or(true);
1066 if a_norm == b_norm
1067 || (!a_unknown && b_unknown)
1068 || (a_norm.is_some() && b_norm.is_none())
1069 {
1070 Type::Tensor { shape: a_norm }
1071 } else if (a_unknown && !b_unknown) || (a_norm.is_none() && b_norm.is_some()) {
1072 Type::Tensor { shape: b_norm }
1073 } else {
1074 Type::tensor()
1075 }
1076 }
1077 (Type::Logical { shape: a }, Type::Logical { shape: b }) => {
1078 let a_norm = match a {
1079 Some(dims) if dims.is_empty() => None,
1080 _ => a.clone(),
1081 };
1082 let b_norm = match b {
1083 Some(dims) if dims.is_empty() => None,
1084 _ => b.clone(),
1085 };
1086 let a_unknown = a_norm
1087 .as_ref()
1088 .map(|dims| dims.iter().all(|d| d.is_none()))
1089 .unwrap_or(true);
1090 let b_unknown = b_norm
1091 .as_ref()
1092 .map(|dims| dims.iter().all(|d| d.is_none()))
1093 .unwrap_or(true);
1094 if a_norm == b_norm
1095 || (!a_unknown && b_unknown)
1096 || (a_norm.is_some() && b_norm.is_none())
1097 {
1098 Type::Logical { shape: a_norm }
1099 } else if (a_unknown && !b_unknown) || (a_norm.is_none() && b_norm.is_some()) {
1100 Type::Logical { shape: b_norm }
1101 } else {
1102 Type::logical()
1103 }
1104 }
1105 (Type::Struct { known_fields: a }, Type::Struct { known_fields: b }) => match (a, b) {
1106 (None, None) => Type::Struct { known_fields: None },
1107 (Some(ka), None) | (None, Some(ka)) => Type::Struct {
1108 known_fields: Some(ka.clone()),
1109 },
1110 (Some(ka), Some(kb)) => {
1111 let mut set: std::collections::BTreeSet<String> = ka.iter().cloned().collect();
1112 set.extend(kb.iter().cloned());
1113 Type::Struct {
1114 known_fields: Some(set.into_iter().collect()),
1115 }
1116 }
1117 },
1118 (Type::OutputList(a), Type::OutputList(b)) => {
1119 if a.len() == b.len() {
1120 let items = a
1121 .iter()
1122 .zip(b.iter())
1123 .map(|(lhs, rhs)| lhs.unify(rhs))
1124 .collect();
1125 Type::OutputList(items)
1126 } else {
1127 Type::OutputList(vec![Type::Unknown; a.len().max(b.len())])
1128 }
1129 }
1130 (a, b) if a == b => a.clone(),
1131 _ => Type::Union(vec![self.clone(), other.clone()]),
1132 }
1133 }
1134
1135 pub fn from_value(value: &Value) -> Type {
1137 match value {
1138 Value::Int(_) => Type::Int,
1139 Value::Num(_) => Type::Num,
1140 Value::Complex(_, _) => Type::Num, Value::Bool(_) => Type::Bool,
1142 Value::LogicalArray(arr) => Type::Logical {
1143 shape: Some(arr.shape.iter().map(|&d| Some(d)).collect()),
1144 },
1145 Value::String(_) => Type::String,
1146 Value::StringArray(_sa) => {
1147 Type::cell_of(Type::String)
1149 }
1150 Value::Tensor(t) => Type::Tensor {
1151 shape: Some(t.shape.iter().map(|&d| Some(d)).collect()),
1152 },
1153 Value::ComplexTensor(t) => Type::Tensor {
1154 shape: Some(t.shape.iter().map(|&d| Some(d)).collect()),
1155 },
1156 Value::Cell(cells) => {
1157 if cells.data.is_empty() {
1158 Type::cell()
1159 } else {
1160 let element_type = Type::from_value(&cells.data[0]);
1162 Type::Cell {
1163 element_type: Some(Box::new(element_type)),
1164 length: Some(cells.data.len()),
1165 }
1166 }
1167 }
1168 Value::GpuTensor(h) => Type::Tensor {
1169 shape: Some(h.shape.iter().map(|&d| Some(d)).collect()),
1170 },
1171 Value::Object(_) => Type::Unknown,
1172 Value::HandleObject(_) => Type::Unknown,
1173 Value::Listener(_) => Type::Unknown,
1174 Value::Struct(_) => Type::Struct { known_fields: None },
1175 Value::FunctionHandle(_)
1176 | Value::ExternalFunctionHandle(_)
1177 | Value::MethodFunctionHandle(_)
1178 | Value::BoundFunctionHandle { .. } => Type::Function {
1179 params: vec![Type::Unknown],
1180 returns: Box::new(Type::Unknown),
1181 },
1182 Value::Closure(_) => Type::Function {
1183 params: vec![Type::Unknown],
1184 returns: Box::new(Type::Unknown),
1185 },
1186 Value::ClassRef(_) => Type::Unknown,
1187 Value::MException(_) => Type::Unknown,
1188 Value::CharArray(ca) => {
1189 Type::Cell {
1191 element_type: Some(Box::new(Type::String)),
1192 length: Some(ca.rows * ca.cols),
1193 }
1194 }
1195 Value::OutputList(values) => {
1196 Type::OutputList(values.iter().map(Type::from_value).collect())
1197 }
1198 }
1199 }
1200}
1201
/// Closure value: a named function together with its captured values.
#[derive(Debug, Clone, PartialEq)]
pub struct Closure {
    /// Name of the function the closure refers to.
    pub function_name: String,
    /// Optional numeric reference to a bound function; NOTE(review): what the
    /// number refers to (index? id?) is not visible here.
    pub bound_function: Option<usize>,
    /// Values captured by the closure.
    pub captures: Vec<Value>,
}
1208
/// Category tag for a builtin; presumably used to route operations to an
/// accelerator backend — TODO confirm against the code that consumes it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelTag {
    Unary,
    Elementwise,
    Reduction,
    MatMul,
    Transpose,
    ArrayConstruct,
}
1219
/// Error type carried by builtin results; an alias for `runmat_async::RuntimeError`.
pub type BuiltinControlFlow = runmat_async::RuntimeError;

/// Boxed, pinned future resolving to a builtin's result: either a [`Value`]
/// or a [`BuiltinControlFlow`] error.
pub type BuiltinFuture = Pin<Box<dyn Future<Output = Result<Value, BuiltinControlFlow>> + 'static>>;
1225
/// Extra information handed to context-aware type resolvers
/// (see `TypeResolverWithContext`).
#[derive(Clone, Debug, Default)]
pub struct ResolveContext {
    /// One [`LiteralValue`] per argument, in argument order; presumably
    /// `LiteralValue::Unknown` marks arguments that are not literals — TODO confirm.
    pub literal_args: Vec<LiteralValue>,
}
1230
/// Literal value of an argument as stored in a [`ResolveContext`].
#[derive(Clone, Debug, PartialEq)]
pub enum LiteralValue {
    /// Numeric literal.
    Number(f64),
    /// Boolean literal.
    Bool(bool),
    /// String literal.
    String(String),
    /// Vector literal whose elements are themselves literals (possibly nested).
    Vector(Vec<LiteralValue>),
    /// No literal value is available for this argument.
    Unknown,
}
1239
1240impl ResolveContext {
1241 pub fn new(literal_args: Vec<LiteralValue>) -> Self {
1242 Self { literal_args }
1243 }
1244
1245 pub fn numeric_dims(&self) -> Vec<Option<usize>> {
1246 self.numeric_dims_from(0)
1247 }
1248
1249 pub fn numeric_dims_from(&self, start: usize) -> Vec<Option<usize>> {
1250 let slice = self.literal_args.get(start..).unwrap_or(&[]);
1251 if let Some(LiteralValue::Vector(values)) = slice.first() {
1252 return values
1253 .iter()
1254 .map(Self::numeric_dimension_from_literal)
1255 .collect();
1256 }
1257 slice
1258 .iter()
1259 .map(Self::numeric_dimension_from_literal)
1260 .collect()
1261 }
1262
1263 pub fn literal_string_at(&self, index: usize) -> Option<String> {
1264 match self.literal_args.get(index) {
1265 Some(LiteralValue::String(value)) => Some(value.to_ascii_lowercase()),
1266 _ => None,
1267 }
1268 }
1269
1270 pub fn literal_bool_at(&self, index: usize) -> Option<bool> {
1271 match self.literal_args.get(index) {
1272 Some(LiteralValue::Bool(value)) => Some(*value),
1273 _ => None,
1274 }
1275 }
1276
1277 pub fn literal_vector_at(&self, index: usize) -> Option<Vec<LiteralValue>> {
1278 match self.literal_args.get(index) {
1279 Some(LiteralValue::Vector(values)) => Some(values.clone()),
1280 _ => None,
1281 }
1282 }
1283
1284 pub fn numeric_vector_at(&self, index: usize) -> Option<Vec<Option<usize>>> {
1285 let values = match self.literal_args.get(index) {
1286 Some(LiteralValue::Vector(values)) => values,
1287 _ => return None,
1288 };
1289 if values
1290 .iter()
1291 .any(|value| matches!(value, LiteralValue::Vector(_)))
1292 {
1293 return None;
1294 }
1295 Some(
1296 values
1297 .iter()
1298 .map(Self::numeric_dimension_from_literal)
1299 .collect(),
1300 )
1301 }
1302
1303 fn numeric_dimension_from_literal(value: &LiteralValue) -> Option<usize> {
1304 match value {
1305 LiteralValue::Number(num) => {
1306 if num.is_finite() {
1307 let rounded = num.round();
1308 if (num - rounded).abs() <= 1e-9 && rounded >= 0.0 {
1309 return Some(rounded as usize);
1310 }
1311 }
1312 None
1313 }
1314 _ => None,
1315 }
1316 }
1317}
1318
/// Unit tests for the literal-argument accessors of `ResolveContext`.
#[cfg(test)]
mod resolve_context_tests {
    use super::{LiteralValue, ResolveContext};

    // A leading vector literal supplies the dimensions.
    #[test]
    fn numeric_dims_reads_vector_literal() {
        let ctx = ResolveContext::new(vec![LiteralValue::Vector(vec![
            LiteralValue::Number(2.0),
            LiteralValue::Number(3.0),
        ])]);
        assert_eq!(ctx.numeric_dims(), vec![Some(2), Some(3)]);
    }

    // Non-numeric arguments still occupy a slot, reported as `None`.
    #[test]
    fn numeric_dims_skips_non_numeric_entries() {
        let ctx = ResolveContext::new(vec![
            LiteralValue::Number(4.0),
            LiteralValue::String("like".to_string()),
            LiteralValue::Unknown,
        ]);
        assert_eq!(ctx.numeric_dims(), vec![Some(4), None, None]);
    }

    // Arguments after a leading vector literal are ignored.
    #[test]
    fn numeric_dims_prefers_vector_even_with_trailing_args() {
        let ctx = ResolveContext::new(vec![
            LiteralValue::Vector(vec![LiteralValue::Number(1.0), LiteralValue::Number(5.0)]),
            LiteralValue::String("like".to_string()),
        ]);
        assert_eq!(ctx.numeric_dims(), vec![Some(1), Some(5)]);
    }

    // String literals are returned ASCII-lowercased.
    #[test]
    fn literal_string_is_lowercased() {
        let ctx = ResolveContext::new(vec![LiteralValue::String("OmItNaN".to_string())]);
        assert_eq!(ctx.literal_string_at(0), Some("omitnan".to_string()));
    }

    // Boolean literals are returned as-is.
    #[test]
    fn literal_bool_is_available() {
        let ctx = ResolveContext::new(vec![LiteralValue::Bool(true)]);
        assert_eq!(ctx.literal_bool_at(0), Some(true));
    }

    // Vector literals are returned as an equal copy, including `Unknown` entries.
    #[test]
    fn literal_vector_at_returns_clone() {
        let ctx = ResolveContext::new(vec![LiteralValue::Vector(vec![
            LiteralValue::Number(7.0),
            LiteralValue::Unknown,
        ])]);
        assert_eq!(
            ctx.literal_vector_at(0),
            Some(vec![LiteralValue::Number(7.0), LiteralValue::Unknown])
        );
    }

    // A vector containing another vector is not a valid dimension list.
    #[test]
    fn numeric_vector_at_rejects_nested_vectors() {
        let ctx = ResolveContext::new(vec![LiteralValue::Vector(vec![LiteralValue::Vector(
            vec![LiteralValue::Number(1.0)],
        )])]);
        assert_eq!(ctx.numeric_vector_at(0), None);
    }
}
1383
/// Function pointer that derives a result `Type` from argument types alone.
pub type TypeResolver = fn(args: &[Type]) -> Type;
/// Function pointer that derives a result `Type` from argument types plus a
/// `ResolveContext`.
pub type TypeResolverWithContext = fn(args: &[Type], ctx: &ResolveContext) -> Type;
1386
/// The two supported resolver signatures, wrapped so either can be stored in
/// one field.
#[derive(Clone, Copy, Debug)]
pub enum TypeResolverKind {
    /// Resolver that only looks at argument types.
    Simple(TypeResolver),
    /// Resolver that additionally receives a `ResolveContext`.
    WithContext(TypeResolverWithContext),
}
1392
/// Wraps a context-free resolver in `TypeResolverKind::Simple`.
pub fn type_resolver_kind(resolver: TypeResolver) -> TypeResolverKind {
    TypeResolverKind::Simple(resolver)
}
1396
/// Wraps a context-aware resolver in `TypeResolverKind::WithContext`.
pub fn type_resolver_kind_ctx(resolver: TypeResolverWithContext) -> TypeResolverKind {
    TypeResolverKind::WithContext(resolver)
}
1400
/// Output mode recorded on a `BuiltinDescriptor`.
// NOTE(review): variant names suggest fixed vs. caller-requested output counts — confirm with consumers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinOutputMode {
    Fixed,
    ByRequestedOutputCount,
}
1406
/// Completion policy recorded on a `BuiltinDescriptor`.
// NOTE(review): presumably controls where a builtin is offered by tooling — confirm with consumers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinCompletionPolicy {
    Public,
    MethodOnly,
    HiddenInternal,
}
1413
/// Arity classification stored on each `BuiltinParamDescriptor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinParamArity {
    Required,
    Optional,
    Variadic,
}
1420
/// Coarse parameter category stored on each `BuiltinParamDescriptor`.
// NOTE(review): the exact meaning of each category is defined by whatever consumes the descriptors — not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinParamType {
    Any,
    NumericScalar,
    IntegerScalar,
    StringScalar,
    NumericArray,
    LogicalArray,
    SizeArg,
    LikePrototype,
    AxesHandle,
    StyleSpec,
    PropertyName,
    PropertyValue,
}
1436
/// Static, serializable description of one input or output parameter.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinParamDescriptor {
    /// Parameter name.
    pub name: &'static str,
    /// Parameter category.
    pub ty: BuiltinParamType,
    /// Whether the parameter is required, optional or variadic.
    pub arity: BuiltinParamArity,
    /// Textual default, if one is recorded.
    pub default: Option<&'static str>,
    /// Free-form description text.
    pub description: &'static str,
}
1445
/// Static, serializable description of one call signature: a label plus its
/// input and output parameter lists.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinSignatureDescriptor {
    pub label: &'static str,
    pub inputs: &'static [BuiltinParamDescriptor],
    pub outputs: &'static [BuiltinParamDescriptor],
}
1452
/// Static, serializable description of one error a builtin can report.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinErrorDescriptor {
    pub code: &'static str,
    /// Optional identifier string for the error.
    pub identifier: Option<&'static str>,
    /// Text describing the condition under which the error is raised.
    pub when: &'static str,
    pub message: &'static str,
}
1460
/// Static, serializable metadata bundle that can be attached to a
/// `BuiltinFunction` through its `descriptor` field.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinDescriptor {
    pub signatures: &'static [BuiltinSignatureDescriptor],
    pub output_mode: BuiltinOutputMode,
    pub completion_policy: BuiltinCompletionPolicy,
    pub errors: &'static [BuiltinErrorDescriptor],
}
1468
/// Registry record for one builtin: metadata, type information and the
/// implementation function pointer.
#[derive(Debug, Clone)]
pub struct BuiltinFunction {
    /// Name used for lookup (lookups in this file are ASCII case-insensitive).
    pub name: &'static str,
    pub description: &'static str,
    pub category: &'static str,
    pub doc: &'static str,
    pub examples: &'static str,
    pub param_types: Vec<Type>,
    /// Return type reported when no `type_resolver` is set.
    pub return_type: Type,
    /// Optional resolver consulted by `infer_return_type_with_context`.
    pub type_resolver: Option<TypeResolverKind>,
    /// Entry point invoked with the argument values.
    pub implementation: fn(&[Value]) -> BuiltinFuture,
    pub accel_tags: &'static [AccelTag],
    pub is_sink: bool,
    /// Flag reported by `suppresses_auto_output`.
    pub suppress_auto_output: bool,
    /// Optional descriptor; `new` leaves it as `None`.
    pub descriptor: Option<&'static BuiltinDescriptor>,
}
1486
1487impl BuiltinFunction {
1488 #[allow(clippy::too_many_arguments)]
1489 pub fn new(
1490 name: &'static str,
1491 description: &'static str,
1492 category: &'static str,
1493 doc: &'static str,
1494 examples: &'static str,
1495 param_types: Vec<Type>,
1496 return_type: Type,
1497 type_resolver: Option<TypeResolverKind>,
1498 implementation: fn(&[Value]) -> BuiltinFuture,
1499 accel_tags: &'static [AccelTag],
1500 is_sink: bool,
1501 suppress_auto_output: bool,
1502 ) -> Self {
1503 Self {
1504 name,
1505 description,
1506 category,
1507 doc,
1508 examples,
1509 param_types,
1510 return_type,
1511 type_resolver,
1512 implementation,
1513 accel_tags,
1514 is_sink,
1515 suppress_auto_output,
1516 descriptor: None,
1517 }
1518 }
1519
1520 pub fn with_descriptor(mut self, descriptor: &'static BuiltinDescriptor) -> Self {
1521 self.descriptor = Some(descriptor);
1522 self
1523 }
1524
1525 pub fn with_descriptor_option(
1526 mut self,
1527 descriptor: Option<&'static BuiltinDescriptor>,
1528 ) -> Self {
1529 self.descriptor = descriptor;
1530 self
1531 }
1532
1533 pub fn infer_return_type(&self, args: &[Type]) -> Type {
1534 self.infer_return_type_with_context(args, &ResolveContext::default())
1535 }
1536
1537 pub fn infer_return_type_with_context(&self, args: &[Type], ctx: &ResolveContext) -> Type {
1538 if let Some(resolver) = self.type_resolver {
1539 return match resolver {
1540 TypeResolverKind::Simple(resolver) => resolver(args),
1541 TypeResolverKind::WithContext(resolver) => resolver(args, ctx),
1542 };
1543 }
1544 self.return_type.clone()
1545 }
1546
1547 pub fn semantics(&self) -> BuiltinSemantics {
1548 semantics::builtin_semantics_for(self)
1549 }
1550}
1551
/// A named constant value held in the constant registry.
#[derive(Clone)]
pub struct Constant {
    pub name: &'static str,
    pub value: Value,
}
1558
1559pub mod semantics;
1560pub mod shape_rules;
1561
1562pub use semantics::{
1563 builtin_semantics_for, builtin_semantics_for_name, BuiltinAsyncBehavior, BuiltinCompatibility,
1564 BuiltinEffects, BuiltinEnvironmentEffect, BuiltinPurity, BuiltinSemanticKind, BuiltinSemantics,
1565 BuiltinWorkspaceEffect, ConcatKind, ShapeTransformKind,
1566};
1567
1568impl std::fmt::Debug for Constant {
1569 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1570 write!(
1571 f,
1572 "Constant {{ name: {:?}, value: {:?} }}",
1573 self.name, self.value
1574 )
1575 }
1576}
1577
// On non-wasm targets `BuiltinFunction` and `Constant` are declared as
// `inventory` collection types.
#[cfg(not(target_arch = "wasm32"))]
inventory::collect!(BuiltinFunction);
#[cfg(not(target_arch = "wasm32"))]
inventory::collect!(Constant);
1582
/// Returns every `BuiltinFunction` gathered through `inventory` (non-wasm targets).
#[cfg(not(target_arch = "wasm32"))]
pub fn builtin_functions() -> Vec<&'static BuiltinFunction> {
    inventory::iter::<BuiltinFunction>().collect()
}
1587
/// Returns every `BuiltinFunction` held by `wasm_registry` (wasm targets).
#[cfg(target_arch = "wasm32")]
pub fn builtin_functions() -> Vec<&'static BuiltinFunction> {
    wasm_registry::builtin_functions()
}
1592
// Lazily built name -> builtin index; keys are ASCII-lower-cased names
// (see `builtin_lookup_map`).
#[cfg(not(target_arch = "wasm32"))]
static BUILTIN_LOOKUP: OnceLock<HashMap<String, &'static BuiltinFunction>> = OnceLock::new();
1595
1596#[cfg(not(target_arch = "wasm32"))]
1597fn builtin_lookup_map() -> &'static HashMap<String, &'static BuiltinFunction> {
1598 BUILTIN_LOOKUP.get_or_init(|| {
1599 let mut map = HashMap::new();
1600 for func in builtin_functions() {
1601 map.insert(func.name.to_ascii_lowercase(), func);
1602 }
1603 map
1604 })
1605}
1606
1607#[cfg(not(target_arch = "wasm32"))]
1608pub fn builtin_function_by_name(name: &str) -> Option<&'static BuiltinFunction> {
1609 builtin_lookup_map()
1610 .get(&name.to_ascii_lowercase())
1611 .copied()
1612}
1613
/// Finds a builtin by name, ignoring ASCII case (wasm targets).
///
/// Scans the registry snapshot and returns the first match.
#[cfg(target_arch = "wasm32")]
pub fn builtin_function_by_name(name: &str) -> Option<&'static BuiltinFunction> {
    for candidate in wasm_registry::builtin_functions() {
        if candidate.name.eq_ignore_ascii_case(name) {
            return Some(candidate);
        }
    }
    None
}
1620
1621pub fn suppresses_auto_output(name: &str) -> bool {
1622 builtin_function_by_name(name)
1623 .map(|f| f.suppress_auto_output)
1624 .unwrap_or(false)
1625}
1626
/// Returns every `Constant` gathered through `inventory` (non-wasm targets).
#[cfg(not(target_arch = "wasm32"))]
pub fn constants() -> Vec<&'static Constant> {
    inventory::iter::<Constant>().collect()
}
1631
/// Returns every `Constant` held by `wasm_registry` (wasm targets).
#[cfg(target_arch = "wasm32")]
pub fn constants() -> Vec<&'static Constant> {
    wasm_registry::constants()
}
1636
/// Documentation record for a builtin; every field except `name` is optional.
#[derive(Debug)]
pub struct BuiltinDoc {
    pub name: &'static str,
    pub category: Option<&'static str>,
    pub summary: Option<&'static str>,
    pub keywords: Option<&'static str>,
    pub errors: Option<&'static str>,
    pub related: Option<&'static str>,
    pub introduced: Option<&'static str>,
    pub status: Option<&'static str>,
    pub examples: Option<&'static str>,
}
1653
// On non-wasm targets `BuiltinDoc` is declared as an `inventory` collection type.
#[cfg(not(target_arch = "wasm32"))]
inventory::collect!(BuiltinDoc);
1656
/// Returns every `BuiltinDoc` gathered through `inventory` (non-wasm targets).
#[cfg(not(target_arch = "wasm32"))]
pub fn builtin_docs() -> Vec<&'static BuiltinDoc> {
    inventory::iter::<BuiltinDoc>().collect()
}
1661
/// Returns every `BuiltinDoc` held by `wasm_registry` (wasm targets).
#[cfg(target_arch = "wasm32")]
pub fn builtin_docs() -> Vec<&'static BuiltinDoc> {
    wasm_registry::builtin_docs()
}
1666
/// Numeric display mode consulted by `format_number`; `Short` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum FormatMode {
    /// Rendered by `fmt_short`.
    #[default]
    Short,
    /// Rendered by `fmt_long`.
    Long,
    /// Scientific notation with 4 fractional digits.
    ShortE,
    /// Scientific notation with 14 fractional digits.
    LongE,
    /// Compact rendering with 5 significant digits.
    ShortG,
    /// Compact rendering with 15 significant digits.
    LongG,
    /// Rendered by `fmt_rational` as an integer or `n/d` fraction.
    Rational,
    /// 16 hexadecimal digits of the value's bit pattern.
    Hex,
}
1692
// Current display mode, declared through `runmat_thread_local!`; starts as
// `FormatMode::Short`.
runmat_thread_local! {
    static DISPLAY_FORMAT: RefCell<FormatMode> = const { RefCell::new(FormatMode::Short) };
}
1696
/// Stores `mode` in `DISPLAY_FORMAT`, replacing the previous mode.
pub fn set_display_format(mode: FormatMode) {
    DISPLAY_FORMAT.with(|c| *c.borrow_mut() = mode);
}
1700
/// Returns a copy of the mode currently stored in `DISPLAY_FORMAT`.
pub fn get_display_format() -> FormatMode {
    DISPLAY_FORMAT.with(|c| *c.borrow())
}
1704
1705pub fn format_number(value: f64) -> String {
1707 if value.is_nan() {
1708 return "NaN".to_string();
1709 }
1710 if value.is_infinite() {
1711 return if value.is_sign_negative() {
1712 "-Inf"
1713 } else {
1714 "Inf"
1715 }
1716 .to_string();
1717 }
1718 let mode = get_display_format();
1719 if mode == FormatMode::Hex {
1720 return fmt_hex(value);
1721 }
1722 let v = if value == 0.0 { 0.0 } else { value };
1723 match mode {
1724 FormatMode::Short => fmt_short(v),
1725 FormatMode::Long => fmt_long(v),
1726 FormatMode::ShortE => fmt_sci(v, 4),
1727 FormatMode::LongE => fmt_sci(v, 14),
1728 FormatMode::ShortG => fmt_compact(v, 5),
1729 FormatMode::LongG => fmt_compact(v, 15),
1730 FormatMode::Rational => fmt_rational(v),
1731 FormatMode::Hex => unreachable!("hex mode handled before zero normalization"),
1732 }
1733}
1734
/// Rewrites the exponent of a Rust-style scientific string so it carries an
/// explicit sign and at least two digits (`1.5e3` becomes `1.5e+03`).
///
/// Input without an `e` is returned unchanged; an exponent that does not
/// parse as an integer is treated as zero.
fn matlab_exp(s: &str) -> String {
    match s.split_once('e') {
        Some((mantissa, exponent)) => {
            let exp: i32 = exponent.parse().unwrap_or(0);
            let sign = if exp < 0 { '-' } else { '+' };
            format!("{mantissa}e{sign}{:02}", exp.unsigned_abs())
        }
        None => s.to_string(),
    }
}
1746
1747fn fmt_sci(v: f64, dec: usize) -> String {
1748 if v == 0.0 {
1749 return format!("0.{:0>dec$}e+00", 0, dec = dec);
1750 }
1751 let s = format!("{v:.dec$e}");
1752 matlab_exp(&s)
1753}
1754
1755fn fmt_short(v: f64) -> String {
1756 let abs = v.abs();
1757 if abs == 0.0 {
1758 return "0".to_string();
1759 }
1760 if v.fract() == 0.0 && abs < 1e15 {
1761 return format!("{}", v as i64);
1762 }
1763 if (0.001..10000.0).contains(&abs) {
1764 format!("{:.4}", v)
1765 } else {
1766 fmt_sci(v, 4)
1767 }
1768}
1769
1770fn fmt_long(v: f64) -> String {
1771 let abs = v.abs();
1772 if abs == 0.0 {
1773 return "0".to_string();
1774 }
1775 if v.fract() == 0.0 && abs < 1e15 {
1776 return format!("{}", v as i64);
1777 }
1778 if (0.001..10000.0).contains(&abs) {
1779 format!("{:.15}", v)
1780 } else {
1781 fmt_sci(v, 14)
1782 }
1783}
1784
1785fn fmt_compact(v: f64, sig_digits: usize) -> String {
1786 let abs = v.abs();
1787 if abs == 0.0 {
1788 return "0".to_string();
1789 }
1790 let use_scientific = !(1e-4..1e6).contains(&abs);
1791 if use_scientific {
1792 let dec = sig_digits - 1;
1793 let s = format!("{v:.dec$e}");
1794 if let Some(e_pos) = s.find('e') {
1796 let exp_part = &s[e_pos..];
1797 let mut mantissa = s[..e_pos].to_string();
1798 if let Some(dot) = mantissa.find('.') {
1799 let mut end = mantissa.len();
1800 while end > dot + 1 && mantissa.as_bytes()[end - 1] == b'0' {
1801 end -= 1;
1802 }
1803 if mantissa.as_bytes()[end - 1] == b'.' {
1804 end -= 1;
1805 }
1806 mantissa.truncate(end);
1807 }
1808 return matlab_exp(&format!("{mantissa}{exp_part}"));
1809 }
1810 return matlab_exp(&s);
1811 }
1812 let exp10 = abs.log10().floor() as i32;
1813 let decimals = ((sig_digits as i32 - 1 - exp10).max(0)) as usize;
1814 let pow = 10f64.powi(decimals as i32);
1815 let rounded = (v * pow).round() / pow;
1816 let mut s = format!("{rounded:.decimals$}");
1817 if let Some(dot) = s.find('.') {
1818 let mut end = s.len();
1819 while end > dot + 1 && s.as_bytes()[end - 1] == b'0' {
1820 end -= 1;
1821 }
1822 if s.as_bytes()[end - 1] == b'.' {
1823 end -= 1;
1824 }
1825 s.truncate(end);
1826 }
1827 if s.is_empty() || s == "-0" {
1828 s = "0".to_string();
1829 }
1830 s
1831}
1832
/// Renders `v` as an integer or an `n/d` fraction using a continued-fraction
/// expansion.
///
/// Whole numbers below `1e15` in magnitude print as integers. Otherwise the
/// expansion runs for at most 50 steps and stops once the approximation is
/// within `5e-7 * |v|`, the remaining fractional part drops below `1e-10`,
/// the next convergent overflows `i64`, or its denominator would exceed
/// 1,000,000.
fn fmt_rational(v: f64) -> String {
    const MAX_DENOMINATOR: i64 = 1_000_000;

    if v == 0.0 {
        return "0".to_string();
    }
    let magnitude = v.abs();
    if v.fract() == 0.0 && magnitude < 1e15 {
        return format!("{}", v as i64);
    }

    let tolerance = 5e-7 * magnitude;
    // Previous and current convergents; the current one is always the best
    // approximation found so far.
    let (mut prev_num, mut cur_num): (i64, i64) = (1, magnitude.floor() as i64);
    let (mut prev_den, mut cur_den): (i64, i64) = (0, 1);
    let mut remainder = magnitude;

    for _ in 0..50 {
        let approximation = cur_num as f64 / cur_den as f64;
        if (magnitude - approximation).abs() <= tolerance {
            break;
        }
        let fractional = remainder.fract();
        if fractional < 1e-10 {
            break;
        }
        remainder = 1.0 / fractional;
        let quotient = remainder.floor() as i64;
        let next_num = quotient
            .checked_mul(cur_num)
            .and_then(|product| product.checked_add(prev_num));
        let next_den = quotient
            .checked_mul(cur_den)
            .and_then(|product| product.checked_add(prev_den));
        let (Some(next_num), Some(next_den)) = (next_num, next_den) else {
            break;
        };
        if next_den > MAX_DENOMINATOR {
            break;
        }
        prev_num = cur_num;
        prev_den = cur_den;
        cur_num = next_num;
        cur_den = next_den;
    }

    let sign = if v < 0.0 { "-" } else { "" };
    match cur_den {
        1 => format!("{sign}{cur_num}"),
        _ => format!("{sign}{cur_num}/{cur_den}"),
    }
}
1886
/// Writes the IEEE-754 bit pattern of `v` as 16 lowercase hex digits.
fn fmt_hex(v: f64) -> String {
    let bits: u64 = v.to_bits();
    format!("{bits:016x}")
}
1890
/// Exception value carrying an identifier, a message and a list of stack
/// entries.
#[derive(Debug, Clone, PartialEq)]
pub struct MException {
    pub identifier: String,
    pub message: String,
    pub stack: Vec<String>,
}
1898
impl MException {
    /// Creates an exception with the given identifier and message and an
    /// empty `stack`.
    pub fn new(identifier: String, message: String) -> Self {
        Self {
            identifier,
            message,
            stack: Vec::new(),
        }
    }
}
1908
/// Reference to a GC-managed value together with its class name and a
/// validity flag.
///
/// Equality compares only the address of `target`.
#[derive(Debug, Clone)]
pub struct HandleRef {
    pub class_name: String,
    pub target: GcPtr<Value>,
    /// Local validity flag; `is_handle_valid` also consults a property on
    /// the target object.
    pub valid: bool,
}
1916
// Property key under which `set_handle_valid` stores, and `is_handle_valid`
// reads, the validity flag on the target object.
const HANDLE_VALID_FLAG_PROPERTY: &str = "__runmat_handle_valid__";
1918
1919pub fn is_handle_valid(handle: &HandleRef) -> bool {
1920 if !handle.valid {
1921 return false;
1922 }
1923 let raw = unsafe { handle.target.as_raw() };
1924 if raw.is_null() {
1925 return false;
1926 }
1927 match unsafe { &*raw } {
1928 Value::Object(obj) => !matches!(
1929 obj.properties.get(HANDLE_VALID_FLAG_PROPERTY),
1930 Some(Value::Bool(false))
1931 ),
1932 _ => true,
1933 }
1934}
1935
1936pub fn set_handle_valid(handle: &HandleRef, valid: bool) -> bool {
1937 let raw = unsafe { handle.target.as_raw_mut() };
1938 if raw.is_null() {
1939 return false;
1940 }
1941 match unsafe { &mut *raw } {
1942 Value::Object(obj) => {
1943 obj.properties
1944 .insert(HANDLE_VALID_FLAG_PROPERTY.to_string(), Value::Bool(valid));
1945 true
1946 }
1947 _ => false,
1948 }
1949}
1950
1951impl PartialEq for HandleRef {
1952 fn eq(&self, other: &Self) -> bool {
1953 let a = unsafe { self.target.as_raw() } as usize;
1954 let b = unsafe { other.target.as_raw() } as usize;
1955 a == b
1956 }
1957}
1958
/// Event listener record: an id, the observed target, the event name, the
/// callback value and two status flags.
#[derive(Debug, Clone, PartialEq)]
pub struct Listener {
    pub id: u64,
    pub target: GcPtr<Value>,
    pub event_name: String,
    pub callback: GcPtr<Value>,
    pub enabled: bool,
    pub valid: bool,
}
1969
1970impl Listener {
1971 pub fn class_name(&self) -> String {
1972 match unsafe { &*self.target.as_raw() } {
1973 Value::Object(o) => o.class_name.clone(),
1974 Value::HandleObject(h) => h.class_name.clone(),
1975 _ => String::new(),
1976 }
1977 }
1978}
1979
// Human-readable rendering for every `Value` variant.
impl fmt::Display for Value {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Value::Int(i) => write!(f, "{}", i.to_i64()),
            // Floating-point output follows the current display format.
            Value::Num(n) => write!(f, "{}", format_number(*n)),
            Value::Complex(re, im) => {
                // Drop whichever part is exactly zero; otherwise join the
                // parts with the sign of the imaginary component.
                if *im == 0.0 {
                    write!(f, "{}", format_number(*re))
                } else if *re == 0.0 {
                    write!(f, "{}i", format_number(*im))
                } else if *im < 0.0 {
                    write!(f, "{}-{}i", format_number(*re), format_number(im.abs()))
                } else {
                    write!(f, "{}+{}i", format_number(*re), format_number(*im))
                }
            }
            // Booleans print as 1 or 0.
            Value::Bool(b) => write!(f, "{}", if *b { 1 } else { 0 }),
            Value::LogicalArray(la) => write!(f, "{la}"),
            Value::String(s) => write!(f, "'{s}'"),
            Value::StringArray(sa) => write!(f, "{sa}"),
            Value::CharArray(ca) => write!(f, "{ca}"),
            Value::Tensor(m) => write!(f, "{m}"),
            Value::ComplexTensor(m) => write!(f, "{m}"),
            Value::Cell(ca) => ca.fmt(f),

            // GPU tensors are summarized by shape and ids; no data is shown.
            Value::GpuTensor(h) => write!(
                f,
                "GpuTensor(shape={:?}, device={}, buffer={})",
                h.shape, h.device_id, h.buffer_id
            ),
            Value::Object(obj) => write!(f, "{}(props={})", obj.class_name, obj.properties.len()),
            Value::HandleObject(h) => {
                // Only the target's address is printed; it is not dereferenced.
                let ptr = unsafe { h.target.as_raw() } as usize;
                write!(
                    f,
                    "<handle {} @0x{:x} valid={}>",
                    h.class_name, ptr, h.valid
                )
            }
            Value::Listener(l) => {
                let ptr = unsafe { l.target.as_raw() } as usize;
                write!(
                    f,
                    "<listener id={} {}@0x{:x} '{}' enabled={} valid={}>",
                    l.id,
                    l.class_name(),
                    ptr,
                    l.event_name,
                    l.enabled,
                    l.valid
                )
            }
            Value::Struct(st) => {
                // Fields are listed as `key: value`, comma separated.
                write!(f, "struct {{")?;
                for (i, (key, val)) in st.fields.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{}: {}", key, val)?;
                }
                write!(f, "}}")
            }
            Value::OutputList(values) => {
                write!(f, "[")?;
                for (i, value) in values.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{}", value)?;
                }
                write!(f, "]")
            }
            // Every function-handle flavour prints as `@name`.
            Value::FunctionHandle(name)
            | Value::ExternalFunctionHandle(name)
            | Value::MethodFunctionHandle(name) => {
                write!(f, "@{name}")
            }
            Value::BoundFunctionHandle { name, .. } => write!(f, "@{name}"),
            Value::Closure(c) => write!(
                f,
                "<closure {} captures={}>",
                c.function_name,
                c.captures.len()
            ),
            Value::ClassRef(name) => write!(f, "<class {name}>"),
            Value::MException(e) => write!(
                f,
                "MException(identifier='{}', message='{}')",
                e.identifier, e.message
            ),
        }
    }
}
2073
2074impl fmt::Display for ComplexTensor {
2075 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2076 match self.shape.len() {
2077 0 | 1 => {
2078 write!(f, "[")?;
2079 for (i, (re, im)) in self.data.iter().enumerate() {
2080 if i > 0 {
2081 write!(f, " ")?;
2082 }
2083 let s = Value::Complex(*re, *im).to_string();
2084 write!(f, "{s}")?;
2085 }
2086 write!(f, "]")
2087 }
2088 2 => {
2089 let rows = self.rows;
2090 let cols = self.cols;
2091 write!(f, "[")?;
2092 for r in 0..rows {
2093 for c in 0..cols {
2094 if c > 0 {
2095 write!(f, " ")?;
2096 }
2097 let (re, im) = self.data[r + c * rows];
2098 let s = Value::Complex(re, im).to_string();
2099 write!(f, "{s}")?;
2100 }
2101 if r + 1 < rows {
2102 write!(f, "; ")?;
2103 }
2104 }
2105 write!(f, "]")
2106 }
2107 _ => {
2108 if should_expand_nd_display(&self.shape) {
2109 write_nd_pages(f, &self.shape, |f, idx| {
2110 let (re, im) = self.data[idx];
2111 write!(f, "{}", Value::Complex(re, im))
2112 })
2113 } else {
2114 write!(f, "ComplexTensor(shape={:?})", self.shape)
2115 }
2116 }
2117 }
2118 }
2119}
2120
// Tests for number formatting and N-D array display.
#[cfg(test)]
mod display_tests {
    use super::{
        fmt_rational, format_number, set_display_format, ComplexTensor, FormatMode, LogicalArray,
        Tensor,
    };

    // Regression guard: very large magnitudes must not panic inside the
    // continued-fraction arithmetic, for either sign.
    #[test]
    fn fmt_rational_large_value_with_tiny_fract_does_not_overflow() {
        let result = std::panic::catch_unwind(|| fmt_rational(1_000_000_000_000_000.000_1));
        assert!(
            result.is_ok(),
            "fmt_rational panicked on large value with tiny fract"
        );

        let result = std::panic::catch_unwind(|| fmt_rational(-1_000_000_000_000_000.000_1));
        assert!(
            result.is_ok(),
            "fmt_rational panicked on negative large value with tiny fract"
        );
    }

    // A small 3-D tensor is printed page by page with `(:, :, k) =` headers.
    #[test]
    fn tensor_nd_display_uses_page_headers() {
        let tensor = Tensor::new(
            vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            vec![2, 3, 2],
        )
        .expect("tensor");
        let rendered = tensor.to_string();
        assert!(rendered.contains("(:, :, 1) ="));
        assert!(rendered.contains("(:, :, 2) ="));
        assert!(rendered.contains(" 1 0 0"));
    }

    // A 4097-element N-D tensor is summarized by shape instead of expanded.
    #[test]
    fn tensor_nd_display_falls_back_for_large_arrays() {
        let tensor = Tensor::new(vec![0.0; 4097], vec![1, 1, 4097]).expect("tensor");
        assert_eq!(tensor.to_string(), "Tensor(shape=[1, 1, 4097])");
    }

    // Logical arrays follow the same rule, with their own summary text.
    #[test]
    fn logical_nd_display_uses_headers_and_fallback_summary() {
        let logical =
            LogicalArray::new(vec![1, 0, 0, 1, 1, 0, 0, 1], vec![2, 2, 2]).expect("logical");
        let rendered = logical.to_string();
        assert!(rendered.contains("(:, :, 1) ="));
        assert!(rendered.contains("(:, :, 2) ="));

        let large = LogicalArray::new(vec![1; 4097], vec![1, 1, 4097]).expect("large logical");
        assert_eq!(large.to_string(), "1x1x4097 logical array");
    }

    // Complex N-D tensors also print page headers.
    #[test]
    fn complex_nd_display_uses_page_headers() {
        let complex = ComplexTensor::new(
            vec![(1.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0)],
            vec![2, 1, 2],
        )
        .expect("complex");
        let rendered = complex.to_string();
        assert!(rendered.contains("(:, :, 1) ="));
        assert!(rendered.contains("(:, :, 2) ="));
    }

    // Hex mode must show the sign bit of -0.0; the format is restored to
    // `Short` at the end of the test.
    #[test]
    fn format_hex_preserves_negative_zero_sign_bit() {
        set_display_format(FormatMode::Hex);
        assert_eq!(format_number(-0.0), "8000000000000000");
        assert_eq!(format_number(0.0), "0000000000000000");
        set_display_format(FormatMode::Short);
    }
}
2197
/// Cell array: GC handles to the element values plus shape information.
#[derive(Debug, Clone, PartialEq)]
pub struct CellArray {
    /// Element handles; `get` reads element (row, col) at `row * cols + col`.
    pub data: Vec<GcPtr<Value>>,
    /// Full shape vector.
    pub shape: Vec<usize>,
    /// Row count set by the constructors from `shape` via `shape_rows_cols`.
    pub rows: usize,
    /// Column count set by the constructors from `shape` via `shape_rows_cols`.
    pub cols: usize,
}
2208
2209impl CellArray {
2210 pub fn new_handles(
2211 handles: Vec<GcPtr<Value>>,
2212 rows: usize,
2213 cols: usize,
2214 ) -> Result<Self, String> {
2215 Self::new_handles_with_shape(handles, vec![rows, cols])
2216 }
2217
2218 pub fn new_handles_with_shape(
2219 handles: Vec<GcPtr<Value>>,
2220 shape: Vec<usize>,
2221 ) -> Result<Self, String> {
2222 let expected = total_len(&shape)
2223 .ok_or_else(|| "Cell data shape exceeds platform limits".to_string())?;
2224 if expected != handles.len() {
2225 return Err(format!(
2226 "Cell data length {} doesn't match shape {:?} ({} elements)",
2227 handles.len(),
2228 shape,
2229 expected
2230 ));
2231 }
2232 let (rows, cols) = shape_rows_cols(&shape);
2233 Ok(CellArray {
2234 data: handles,
2235 shape,
2236 rows,
2237 cols,
2238 })
2239 }
2240
2241 pub fn new(data: Vec<Value>, rows: usize, cols: usize) -> Result<Self, String> {
2242 Self::new_with_shape(data, vec![rows, cols])
2243 }
2244
2245 pub fn new_with_shape(data: Vec<Value>, shape: Vec<usize>) -> Result<Self, String> {
2246 let expected = total_len(&shape)
2247 .ok_or_else(|| "Cell data shape exceeds platform limits".to_string())?;
2248 if expected != data.len() {
2249 return Err(format!(
2250 "Cell data length {} doesn't match shape {:?} ({} elements)",
2251 data.len(),
2252 shape,
2253 expected
2254 ));
2255 }
2256 let handles: Vec<GcPtr<Value>> = data
2258 .into_iter()
2259 .map(|v| unsafe { GcPtr::from_raw(Box::into_raw(Box::new(v))) })
2260 .collect();
2261 Self::new_handles_with_shape(handles, shape)
2262 }
2263
2264 pub fn get(&self, row: usize, col: usize) -> Result<Value, String> {
2265 if row >= self.rows || col >= self.cols {
2266 return Err(format!(
2267 "Cell index ({row}, {col}) out of bounds for {}x{} cell array",
2268 self.rows, self.cols
2269 ));
2270 }
2271 Ok((*self.data[row * self.cols + col]).clone())
2272 }
2273}
2274
/// Number of elements implied by `shape`.
///
/// An empty shape counts as zero elements; `None` signals that the product
/// of the dimensions overflowed `usize`.
fn total_len(shape: &[usize]) -> Option<usize> {
    if shape.is_empty() {
        return Some(0);
    }
    let mut count = 1usize;
    for &dim in shape {
        count = count.checked_mul(dim)?;
    }
    Some(count)
}
2283
/// Derives the `(rows, cols)` pair from a shape vector.
///
/// An empty shape gives `(0, 0)`, a one-element shape is treated as a single
/// row, and longer shapes use their first two dimensions.
fn shape_rows_cols(shape: &[usize]) -> (usize, usize) {
    match *shape {
        [] => (0, 0),
        [len] => (1, len),
        [rows, cols, ..] => (rows, cols),
    }
}
2293
2294impl fmt::Display for CellArray {
2295 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2296 let dims: Vec<String> = self.shape.iter().map(|d| d.to_string()).collect();
2297 if self.shape.len() > 2 {
2298 return write!(f, "{} cell array", dims.join("x"));
2299 }
2300 write!(f, "{}x{} cell array", self.rows, self.cols)?;
2301 if self.rows == 0 || self.cols == 0 {
2302 return Ok(());
2303 }
2304 for r in 0..self.rows {
2305 writeln!(f)?;
2306 write!(f, " ")?;
2307 for c in 0..self.cols {
2308 if c > 0 {
2309 write!(f, " ")?;
2310 }
2311 let value = self.get(r, c).unwrap_or_else(|_| Value::Num(f64::NAN));
2312 write!(f, "{{{value}}}")?;
2313 }
2314 }
2315 Ok(())
2316 }
2317}
2318
/// Object instance: a class name plus a map of property values keyed by
/// property name.
#[derive(Debug, Clone, PartialEq)]
pub struct ObjectInstance {
    pub class_name: String,
    pub properties: HashMap<String, Value>,
}
2324
impl ObjectInstance {
    /// Creates an instance of `class_name` with no properties.
    pub fn new(class_name: String) -> Self {
        Self {
            class_name,
            properties: HashMap::new(),
        }
    }

    /// Returns `true` when the instance's class name equals `name` exactly;
    /// parent classes are not consulted.
    pub fn is_class(&self, name: &str) -> bool {
        self.class_name == name
    }
}
2337
/// Access level attached to properties and methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Access {
    Public,
    Private,
    Protected,
}
2345
/// Definition of one class property.
#[derive(Debug, Clone)]
pub struct PropertyDef {
    pub name: String,
    pub is_static: bool,
    pub is_constant: bool,
    pub is_dependent: bool,
    /// Access level for reads.
    pub get_access: Access,
    /// Access level for writes.
    pub set_access: Access,
    /// Optional default value.
    pub default_value: Option<Value>,
}
2356
/// Definition of one class method.
#[derive(Debug, Clone)]
pub struct MethodDef {
    pub name: String,
    pub is_static: bool,
    pub is_abstract: bool,
    pub is_sealed: bool,
    pub access: Access,
    /// Name of the function that implements the method.
    pub function_name: String,
    /// Optional class name; the primitive `zeros` methods set it to their
    /// own class name.
    // NOTE(review): presumably passed implicitly to the implementing function — confirm at the call site.
    pub implicit_class_argument: Option<String>,
}
2367
/// Definition of one class: its name, optional parent, and its own
/// properties and methods keyed by name.
#[derive(Debug, Clone)]
pub struct ClassDef {
    pub name: String,
    /// Parent class name; lookups walk this chain.
    pub parent: Option<String>,
    pub properties: HashMap<String, PropertyDef>,
    pub methods: HashMap<String, MethodDef>,
}
2375
2376use std::sync::Mutex;
2377
// Process-wide registries, each created lazily behind a mutex.
// Class definitions keyed by class name.
static CLASS_REGISTRY: OnceLock<Mutex<HashMap<String, ClassDef>>> = OnceLock::new();
// Names of classes registered as sealed.
static SEALED_CLASS_REGISTRY: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
// Names of classes registered as abstract.
static ABSTRACT_CLASS_REGISTRY: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
// Static property values keyed by (class name, property name).
static STATIC_VALUES: OnceLock<Mutex<HashMap<(String, String), Value>>> = OnceLock::new();
// Enumeration member names keyed by class name.
static ENUMERATION_REGISTRY: OnceLock<Mutex<HashMap<String, HashSet<String>>>> = OnceLock::new();
2383
/// Returns the class registry, seeding it with the primitive classes on
/// first use.
fn registry() -> &'static Mutex<HashMap<String, ClassDef>> {
    CLASS_REGISTRY.get_or_init(|| Mutex::new(primitive_class_registry()))
}
2387
/// Returns the set of sealed class names, creating it empty on first use.
fn sealed_registry() -> &'static Mutex<HashSet<String>> {
    SEALED_CLASS_REGISTRY.get_or_init(|| Mutex::new(HashSet::new()))
}
2391
/// Returns the set of abstract class names, creating it empty on first use.
fn abstract_registry() -> &'static Mutex<HashSet<String>> {
    ABSTRACT_CLASS_REGISTRY.get_or_init(|| Mutex::new(HashSet::new()))
}
2395
/// Returns the enumeration-member map, creating it empty on first use.
fn enumeration_registry() -> &'static Mutex<HashMap<String, HashSet<String>>> {
    ENUMERATION_REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
}
2399
2400fn primitive_class_registry() -> HashMap<String, ClassDef> {
2401 ["double", "single", "logical"]
2402 .into_iter()
2403 .map(|class_name| {
2404 let mut methods = HashMap::new();
2405 methods.insert(
2406 "zeros".to_string(),
2407 MethodDef {
2408 name: "zeros".to_string(),
2409 is_static: true,
2410 is_abstract: false,
2411 is_sealed: false,
2412 access: Access::Public,
2413 function_name: "zeros".to_string(),
2414 implicit_class_argument: Some(class_name.to_string()),
2415 },
2416 );
2417 (
2418 class_name.to_string(),
2419 ClassDef {
2420 name: class_name.to_string(),
2421 parent: None,
2422 properties: HashMap::new(),
2423 methods,
2424 },
2425 )
2426 })
2427 .collect()
2428}
2429
/// Registers `def` as a class that is neither sealed nor abstract.
pub fn register_class(def: ClassDef) {
    register_class_with_modifiers(def, false, false);
}
2433
/// Registers `def` with the given sealed flag; the class is recorded as
/// non-abstract.
pub fn register_class_with_sealed(def: ClassDef, is_sealed: bool) {
    register_class_with_modifiers(def, is_sealed, false);
}
2437
2438pub fn register_class_with_modifiers(def: ClassDef, is_sealed: bool, is_abstract: bool) {
2439 let mut m = registry().lock().unwrap();
2440 let class_name = def.name.clone();
2441 m.insert(class_name.clone(), def);
2442 let mut sealed = sealed_registry().lock().unwrap();
2443 if is_sealed {
2444 sealed.insert(class_name.clone());
2445 } else {
2446 sealed.remove(&class_name);
2447 }
2448 let mut abstract_classes = abstract_registry().lock().unwrap();
2449 if is_abstract {
2450 abstract_classes.insert(class_name.clone());
2451 } else {
2452 abstract_classes.remove(&class_name);
2453 }
2454 enumeration_registry()
2455 .lock()
2456 .unwrap()
2457 .entry(class_name)
2458 .or_default();
2459}
2460
2461pub fn register_class_enumerations(class_name: &str, members: impl IntoIterator<Item = String>) {
2462 let mut registry = enumeration_registry().lock().unwrap();
2463 let entry = registry.entry(class_name.to_string()).or_default();
2464 entry.clear();
2465 entry.extend(members);
2466}
2467
2468pub fn class_has_enumeration_member(class_name: &str, member: &str) -> bool {
2469 enumeration_registry()
2470 .lock()
2471 .unwrap()
2472 .get(class_name)
2473 .is_some_and(|members| members.contains(member))
2474}
2475
/// Returns a clone of the registered definition for `name`, if any.
pub fn get_class(name: &str) -> Option<ClassDef> {
    registry().lock().unwrap().get(name).cloned()
}
2479
/// Returns the names of all registered classes, in the hash map's
/// iteration order.
pub fn class_names() -> Vec<String> {
    registry().lock().unwrap().keys().cloned().collect()
}
2483
/// Reports whether `name` is recorded in the sealed-class set.
pub fn is_class_sealed(name: &str) -> bool {
    sealed_registry().lock().unwrap().contains(name)
}
2487
/// Reports whether `name` is recorded in the abstract-class set.
pub fn is_class_abstract(name: &str) -> bool {
    abstract_registry().lock().unwrap().contains(name)
}
2491
2492pub fn is_class_or_subclass(class_name: &str, ancestor_name: &str) -> bool {
2493 if class_name == ancestor_name {
2494 return true;
2495 }
2496 let reg = registry().lock().unwrap();
2497 let mut current = Some(class_name.to_string());
2498 let mut visited = std::collections::HashSet::new();
2499 while let Some(name) = current {
2500 if !visited.insert(name.clone()) {
2501 break;
2502 }
2503 if name == ancestor_name {
2504 return true;
2505 }
2506 current = reg
2507 .get(&name)
2508 .and_then(|class_def| class_def.parent.clone());
2509 }
2510 false
2511}
2512
2513pub fn lookup_property(class_name: &str, prop: &str) -> Option<(PropertyDef, String)> {
2516 let reg = registry().lock().unwrap();
2517 let mut current = Some(class_name.to_string());
2518 let mut visited = std::collections::HashSet::new();
2519 while let Some(name) = current {
2520 if !visited.insert(name.clone()) {
2521 break;
2522 }
2523 if let Some(cls) = reg.get(&name) {
2524 if let Some(p) = cls.properties.get(prop) {
2525 return Some((p.clone(), name));
2526 }
2527 current = cls.parent.clone();
2528 } else {
2529 break;
2530 }
2531 }
2532 None
2533}
2534
2535pub fn lookup_method(class_name: &str, method: &str) -> Option<(MethodDef, String)> {
2538 let reg = registry().lock().unwrap();
2539 let mut current = Some(class_name.to_string());
2540 let mut visited = std::collections::HashSet::new();
2541 while let Some(name) = current {
2542 if !visited.insert(name.clone()) {
2543 break;
2544 }
2545 if let Some(cls) = reg.get(&name) {
2546 if let Some(m) = cls.methods.get(method) {
2547 return Some((m.clone(), name));
2548 }
2549 current = cls.parent.clone();
2550 } else {
2551 break;
2552 }
2553 }
2554 None
2555}
2556
/// Returns the static-property value store, creating it empty on first use.
fn static_values() -> &'static Mutex<HashMap<(String, String), Value>> {
    STATIC_VALUES.get_or_init(|| Mutex::new(HashMap::new()))
}
2560
2561pub fn get_static_property_value(class_name: &str, prop: &str) -> Option<Value> {
2562 static_values()
2563 .lock()
2564 .unwrap()
2565 .get(&(class_name.to_string(), prop.to_string()))
2566 .cloned()
2567}
2568
2569pub fn set_static_property_value(class_name: &str, prop: &str, value: Value) {
2570 static_values()
2571 .lock()
2572 .unwrap()
2573 .insert((class_name.to_string(), prop.to_string()), value);
2574}
2575
2576pub fn set_static_property_value_in_owner(
2578 class_name: &str,
2579 prop: &str,
2580 value: Value,
2581) -> Result<(), String> {
2582 if let Some((_p, owner)) = lookup_property(class_name, prop) {
2583 set_static_property_value(&owner, prop, value);
2584 Ok(())
2585 } else {
2586 Err(format!("Unknown static property '{class_name}.{prop}'"))
2587 }
2588}
2589
#[cfg(test)]
mod class_registry_tests {
    use super::{
        get_class, lookup_method, lookup_property, register_class, Access, ClassDef, MethodDef,
        PropertyDef,
    };
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU64, Ordering};

    // Supplies the numeric suffix for generated class names; presumably this keeps
    // tests that register classes from colliding with each other.
    static TEST_CLASS_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Returns `prefix`, an underscore, and the next value of the test counter.
    fn unique_class_name(prefix: &str) -> String {
        let suffix = TEST_CLASS_COUNTER.fetch_add(1, Ordering::Relaxed);
        format!("{prefix}_{suffix}")
    }

    /// Registers a class built from the given name, optional parent, properties and methods.
    fn register(
        name: &str,
        parent: Option<&str>,
        properties: HashMap<String, PropertyDef>,
        methods: HashMap<String, MethodDef>,
    ) {
        register_class(ClassDef {
            name: name.to_string(),
            parent: parent.map(str::to_owned),
            properties,
            methods,
        });
    }

    /// Registers two empty classes that name each other as parent and returns the
    /// name of the first one.
    fn register_two_class_cycle(prefix_a: &str, prefix_b: &str) -> String {
        let class_a = unique_class_name(prefix_a);
        let class_b = unique_class_name(prefix_b);
        register(
            &class_a,
            Some(class_b.as_str()),
            HashMap::new(),
            HashMap::new(),
        );
        register(
            &class_b,
            Some(class_a.as_str()),
            HashMap::new(),
            HashMap::new(),
        );
        class_a
    }

    #[test]
    fn primitive_classes_expose_static_zeros_method_metadata() {
        for class_name in ["double", "single", "logical"] {
            // Metadata as stored directly on the class definition.
            let class_def = get_class(class_name).expect("primitive class should be registered");
            let declared = class_def
                .methods
                .get("zeros")
                .expect("primitive class should expose zeros static method");
            assert!(declared.is_static, "zeros should be static on {class_name}");
            assert_eq!(declared.function_name, "zeros");
            assert_eq!(
                declared.implicit_class_argument.as_deref(),
                Some(class_name)
            );

            // The same metadata as seen through the lookup path.
            let (resolved, owner) =
                lookup_method(class_name, "zeros").expect("lookup should find primitive zeros");
            assert_eq!(owner, class_name);
            assert_eq!(resolved.function_name, "zeros");
            assert_eq!(
                resolved.implicit_class_argument.as_deref(),
                Some(class_name)
            );
        }
    }

    #[test]
    fn method_lookup_uses_parent_class_metadata_chain() {
        let parent_name = unique_class_name("plan6_parent");
        let child_name = unique_class_name("plan6_child");

        let inherited = MethodDef {
            name: "parentOnly".to_string(),
            is_static: false,
            is_abstract: false,
            is_sealed: false,
            access: Access::Public,
            function_name: "parentOnly_impl".to_string(),
            implicit_class_argument: None,
        };
        let parent_methods = HashMap::from([("parentOnly".to_string(), inherited)]);
        register(&parent_name, None, HashMap::new(), parent_methods);
        register(
            &child_name,
            Some(parent_name.as_str()),
            HashMap::new(),
            HashMap::new(),
        );

        let (method, owner) = lookup_method(&child_name, "parentOnly")
            .expect("child lookup should resolve inherited method through parent metadata");
        assert_eq!(owner, parent_name);
        assert_eq!(method.function_name, "parentOnly_impl");
    }

    #[test]
    fn method_lookup_handles_parent_cycle() {
        let class_a = register_two_class_cycle("plan6_cycle_method_a", "plan6_cycle_method_b");

        assert!(
            lookup_method(&class_a, "missing").is_none(),
            "cyclic parent metadata should terminate missing method lookup"
        );
    }

    #[test]
    fn property_lookup_uses_parent_class_metadata_chain() {
        let parent_name = unique_class_name("plan6_property_parent");
        let child_name = unique_class_name("plan6_property_child");

        let inherited = PropertyDef {
            name: "parentFlag".to_string(),
            is_static: false,
            is_constant: false,
            is_dependent: false,
            get_access: Access::Public,
            set_access: Access::Public,
            default_value: None,
        };
        let parent_properties = HashMap::from([("parentFlag".to_string(), inherited)]);
        register(&parent_name, None, parent_properties, HashMap::new());
        register(
            &child_name,
            Some(parent_name.as_str()),
            HashMap::new(),
            HashMap::new(),
        );

        let (property, owner) = lookup_property(&child_name, "parentFlag")
            .expect("child property lookup should resolve inherited property through parent");
        assert_eq!(owner, parent_name);
        assert_eq!(property.name, "parentFlag");
        assert!(!property.is_static);
    }

    #[test]
    fn property_lookup_handles_parent_cycle() {
        let class_a =
            register_two_class_cycle("plan6_cycle_property_a", "plan6_cycle_property_b");

        assert!(
            lookup_property(&class_a, "missing").is_none(),
            "cyclic parent metadata should terminate missing property lookup"
        );
    }
}