1pub use inventory;
2use runmat_gc_api::GcPtr;
3use runmat_thread_local::runmat_thread_local;
4use std::cell::RefCell;
5use std::collections::HashMap;
6use std::collections::HashSet;
7use std::convert::TryFrom;
8use std::fmt;
9use std::future::Future;
10use std::pin::Pin;
11
12use indexmap::IndexMap;
13use std::sync::OnceLock;
14
#[cfg(target_arch = "wasm32")]
pub mod wasm_registry {
    //! Explicit builtin/constant/doc registry compiled only for `wasm32`.
    //!
    //! NOTE(review): presumably used in place of `inventory`-based collection on
    //! wasm32 — confirm. Every submitted item is leaked with `Box::leak`, so the
    //! stored `'static` references stay valid for the rest of the program.
    use super::{BuiltinDoc, BuiltinFunction, Constant};
    use once_cell::sync::Lazy;
    use std::sync::Mutex;

    /// Lazily-initialized, mutex-guarded list of leaked registry entries.
    type Registry<T> = Lazy<Mutex<Vec<&'static T>>>;

    static FUNCTIONS: Registry<BuiltinFunction> = Lazy::new(|| Mutex::new(Vec::new()));
    static CONSTANTS: Registry<Constant> = Lazy::new(|| Mutex::new(Vec::new()));
    static DOCS: Registry<BuiltinDoc> = Lazy::new(|| Mutex::new(Vec::new()));
    static REGISTERED: Lazy<Mutex<bool>> = Lazy::new(|| Mutex::new(false));

    /// Leaks `item` and appends the resulting `'static` reference to `registry`.
    ///
    /// Panics if the registry mutex is poisoned.
    fn push_leaked<T: 'static>(registry: &Mutex<Vec<&'static T>>, item: T) {
        let entry: &'static T = Box::leak(Box::new(item));
        registry.lock().unwrap().push(entry);
    }

    /// Returns a copy of the references currently held by `registry`.
    ///
    /// Panics if the registry mutex is poisoned.
    fn snapshot<T: 'static>(registry: &Mutex<Vec<&'static T>>) -> Vec<&'static T> {
        registry.lock().unwrap().clone()
    }

    /// Registers a builtin function (leaked for the program lifetime).
    pub fn submit_builtin_function(func: BuiltinFunction) {
        push_leaked(&FUNCTIONS, func);
    }

    /// Registers a constant (leaked for the program lifetime).
    pub fn submit_constant(constant: Constant) {
        push_leaked(&CONSTANTS, constant);
    }

    /// Registers builtin documentation (leaked for the program lifetime).
    pub fn submit_builtin_doc(doc: BuiltinDoc) {
        push_leaked(&DOCS, doc);
    }

    /// Snapshot of all builtin functions submitted so far, in submission order.
    pub fn builtin_functions() -> Vec<&'static BuiltinFunction> {
        snapshot(&FUNCTIONS)
    }

    /// Snapshot of all constants submitted so far, in submission order.
    pub fn constants() -> Vec<&'static Constant> {
        snapshot(&CONSTANTS)
    }

    /// Snapshot of all builtin docs submitted so far, in submission order.
    pub fn builtin_docs() -> Vec<&'static BuiltinDoc> {
        snapshot(&DOCS)
    }

    /// Sets the "registration done" flag.
    pub fn mark_registered() {
        let mut flag = REGISTERED.lock().unwrap();
        *flag = true;
    }

    /// Reports whether `mark_registered` has been called.
    pub fn is_registered() -> bool {
        let flag = REGISTERED.lock().unwrap();
        *flag
    }
}
66
/// Dynamically-typed runtime value.
///
/// Each variant wraps one kind of runtime data. Several payload types
/// (`CellArray`, `ObjectInstance`, `HandleRef`, `Listener`, `MException`) are
/// defined outside this section of the file.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// Scalar integer tagged with its integer class (see [`IntValue`]).
    Int(IntValue),
    /// Scalar double-precision number.
    Num(f64),
    /// Complex scalar. NOTE(review): presumably `(real, imaginary)` — confirm.
    Complex(f64, f64),
    /// Scalar boolean.
    Bool(bool),
    /// Array of logical (0/1) values.
    LogicalArray(LogicalArray),
    /// Scalar string.
    String(String),
    /// N-dimensional array of strings.
    StringArray(StringArray),
    /// Two-dimensional character array.
    CharArray(CharArray),
    /// Dense real numeric array.
    Tensor(Tensor),
    /// Sparse matrix in compressed-column form (see [`SparseTensor`]).
    SparseTensor(SparseTensor),
    /// Dense complex numeric array.
    ComplexTensor(ComplexTensor),
    /// Cell array (type defined elsewhere).
    Cell(CellArray),
    /// Struct with ordered named fields.
    Struct(StructValue),
    /// Handle to a tensor owned by the `runmat_accelerate_api` provider.
    GpuTensor(runmat_accelerate_api::GpuTensorHandle),
    /// Class instance (type defined elsewhere).
    Object(ObjectInstance),
    /// Reference to a handle object (type defined elsewhere).
    HandleObject(HandleRef),
    /// Event listener (type defined elsewhere).
    Listener(Listener),
    /// List of values. NOTE(review): presumably multiple return values — confirm.
    OutputList(Vec<Value>),
    /// Function handle referring to a function by name.
    FunctionHandle(String),
    /// Function handle by name. NOTE(review): "external" semantics not visible here — confirm.
    ExternalFunctionHandle(String),
    /// Function handle naming a method. NOTE(review): confirm naming scheme.
    MethodFunctionHandle(String),
    /// Function handle carrying a name plus a numeric function reference.
    BoundFunctionHandle {
        /// Function name.
        name: String,
        /// NOTE(review): presumably an index/id of the resolved function — confirm.
        function: usize,
    },
    /// Anonymous function together with its captured values.
    Closure(Closure),
    /// Reference to a class by name.
    ClassRef(String),
    /// Exception object (type defined elsewhere).
    MException(MException),
}
/// Scalar integer tagged with its integer class (`int8` … `uint64`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IntValue {
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
}

impl IntValue {
    /// Converts to `i64`.
    ///
    /// Every class except `uint64` converts losslessly. `uint64` values above
    /// `i64::MAX` saturate to `i64::MAX`.
    pub fn to_i64(&self) -> i64 {
        match *self {
            IntValue::I8(v) => i64::from(v),
            IntValue::I16(v) => i64::from(v),
            IntValue::I32(v) => i64::from(v),
            IntValue::I64(v) => v,
            IntValue::U8(v) => i64::from(v),
            IntValue::U16(v) => i64::from(v),
            IntValue::U32(v) => i64::from(v),
            IntValue::U64(v) => i64::try_from(v).unwrap_or(i64::MAX),
        }
    }

    /// Converts to `f64`.
    ///
    /// `uint64` is converted directly rather than through [`IntValue::to_i64`].
    /// Otherwise values above `i64::MAX` would be clamped to `i64::MAX` first and
    /// lose magnitude.
    pub fn to_f64(&self) -> f64 {
        match *self {
            IntValue::U64(v) => v as f64,
            _ => self.to_i64() as f64,
        }
    }

    /// Returns `true` when the value equals zero.
    pub fn is_zero(&self) -> bool {
        // Saturation in `to_i64` never maps a nonzero value to zero.
        self.to_i64() == 0
    }

    /// Integer class name, e.g. `"int32"` or `"uint8"`.
    pub fn class_name(&self) -> &'static str {
        match self {
            IntValue::I8(_) => "int8",
            IntValue::I16(_) => "int16",
            IntValue::I32(_) => "int32",
            IntValue::I64(_) => "int64",
            IntValue::U8(_) => "uint8",
            IntValue::U16(_) => "uint16",
            IntValue::U32(_) => "uint32",
            IntValue::U64(_) => "uint64",
        }
    }
}
165
166#[derive(Debug, Clone, PartialEq)]
167pub struct StructValue {
168 pub fields: IndexMap<String, Value>,
169}
170
171impl StructValue {
172 pub fn new() -> Self {
173 Self {
174 fields: IndexMap::new(),
175 }
176 }
177
178 pub fn insert(&mut self, name: impl Into<String>, value: Value) -> Option<Value> {
180 self.fields.insert(name.into(), value)
181 }
182
183 pub fn remove(&mut self, name: &str) -> Option<Value> {
185 self.fields.shift_remove(name)
186 }
187
188 pub fn field_names(&self) -> impl Iterator<Item = &String> {
190 self.fields.keys()
191 }
192}
193
194impl Default for StructValue {
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
200#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
201pub enum NumericDType {
202 F64,
203 F32,
204 U8,
205 U16,
206}
207
208impl NumericDType {
209 pub fn class_name(self) -> &'static str {
210 match self {
211 NumericDType::F64 => "double",
212 NumericDType::F32 => "single",
213 NumericDType::U8 => "uint8",
214 NumericDType::U16 => "uint16",
215 }
216 }
217
218 pub fn byte_size(self) -> usize {
219 match self {
220 NumericDType::F64 => 8,
221 NumericDType::F32 => 4,
222 NumericDType::U8 => 1,
223 NumericDType::U16 => 2,
224 }
225 }
226}
227
/// Dense real-valued N-dimensional array.
///
/// Values are always stored as `f64`; `dtype` records the logical element class
/// (for example, `Tensor::from_f32` converts to `f64` and sets `F32`). The 2-D
/// accessors index `data[row + col * rows]`, i.e. column-major.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Element values; `Tensor::new` requires the length to equal the product of `shape`.
    pub data: Vec<f64>,
    /// Dimension sizes.
    pub shape: Vec<usize>,
    /// Cached first dimension; constructors store `shape[0]` (1 for a 1-D shape, 0 for `[]`).
    pub rows: usize,
    /// Cached second dimension; constructors store `shape[1]` (`shape[0]` for 1-D, 0 for `[]`).
    /// NOTE(review): `Tensor::rows()`/`cols()` compute these differently for 1-D and `[]` shapes.
    pub cols: usize,
    /// Logical element class; `data` itself is always `f64`.
    pub dtype: NumericDType,
}
237
/// Sparse real matrix in compressed-sparse-column (CSC) form.
///
/// The entries of column `c` are `row_indices[col_ptrs[c]..col_ptrs[c + 1]]`
/// paired with the same range of `values`. `SparseTensor::new` validates this layout.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseTensor {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// `cols + 1` offsets into `row_indices`/`values`; starts at 0 and ends at nnz.
    pub col_ptrs: Vec<usize>,
    /// Zero-based row index of each stored entry, strictly increasing within a column.
    pub row_indices: Vec<usize>,
    /// Stored nonzero values, parallel to `row_indices`.
    pub values: Vec<f64>,
}
248
/// Dense complex-valued N-dimensional array.
#[derive(Debug, Clone, PartialEq)]
pub struct ComplexTensor {
    /// Element values. NOTE(review): presumably `(real, imaginary)` pairs — confirm.
    pub data: Vec<(f64, f64)>,
    /// Dimension sizes; `ComplexTensor::new` requires their product to equal `data.len()`.
    pub shape: Vec<usize>,
    /// Cached first dimension (`shape[0]`, 1 for 1-D, 0 for `[]`).
    pub rows: usize,
    /// Cached second dimension (`shape[1]`, `shape[0]` for 1-D, 0 for `[]`).
    pub cols: usize,
}
256
/// N-dimensional array of strings.
///
/// The `Display` impl reads 2-D elements as `data[r + c * rows]` (column-major).
#[derive(Debug, Clone, PartialEq)]
pub struct StringArray {
    /// Element strings; `StringArray::new` requires the length to equal the product of `shape`.
    pub data: Vec<String>,
    /// Dimension sizes.
    pub shape: Vec<usize>,
    /// Cached first dimension (`shape[0]`, 1 for 1-D, 0 for `[]`).
    pub rows: usize,
    /// Cached second dimension (`shape[1]`, `shape[0]` for 1-D, 0 for `[]`).
    pub cols: usize,
}
264
/// N-dimensional logical array stored as one byte per element.
#[derive(Debug, Clone, PartialEq)]
pub struct LogicalArray {
    /// Element values; `LogicalArray::new` normalizes every nonzero byte to 1.
    pub data: Vec<u8>,
    /// Dimension sizes; their product equals `data.len()` for arrays built by `new`/`zeros`.
    pub shape: Vec<usize>,
}

impl LogicalArray {
    /// Number of elements described by `shape`, or `None` if the product overflows `usize`.
    ///
    /// A zero-length dimension yields `Some(0)` even when the other dimensions
    /// would overflow if multiplied on their own.
    fn checked_element_count(shape: &[usize]) -> Option<usize> {
        if shape.contains(&0) {
            return Some(0);
        }
        shape.iter().try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    }

    /// Builds a logical array, normalizing every nonzero byte to 1.
    ///
    /// # Errors
    /// Returns an error if the element count of `shape` overflows `usize`, or if
    /// it differs from `data.len()`.
    pub fn new(data: Vec<u8>, shape: Vec<usize>) -> Result<Self, String> {
        // A checked count keeps a wrapped (release-mode) product from accepting a bogus shape.
        let Some(expected) = Self::checked_element_count(&shape) else {
            return Err(format!(
                "LogicalArray shape {shape:?} has too many elements (overflows usize)"
            ));
        };
        if data.len() != expected {
            return Err(format!(
                "LogicalArray data length {} doesn't match shape {:?} ({} elements)",
                data.len(),
                shape,
                expected
            ));
        }
        let mut data = data;
        for byte in &mut data {
            *byte = u8::from(*byte != 0);
        }
        Ok(LogicalArray { data, shape })
    }

    /// All-false (all-zero) array of the given shape.
    ///
    /// # Panics
    /// Panics if the element count of `shape` overflows `usize`.
    pub fn zeros(shape: Vec<usize>) -> Self {
        let count = Self::checked_element_count(&shape)
            .expect("LogicalArray::zeros: shape element count overflows usize");
        LogicalArray {
            data: vec![0u8; count],
            shape,
        }
    }

    /// Number of stored elements.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` when no elements are stored.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
303
/// Two-dimensional array of characters.
///
/// The `Display` impl reads `data[r * cols + c]`, i.e. row-major order.
#[derive(Debug, Clone, PartialEq)]
pub struct CharArray {
    /// Characters; `CharArray::new` requires `rows * cols == data.len()`.
    pub data: Vec<char>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}

impl CharArray {
    /// Builds a 1×N char array holding the characters of `s`.
    pub fn new_row(s: &str) -> Self {
        let data: Vec<char> = s.chars().collect();
        let cols = data.len();
        CharArray { data, rows: 1, cols }
    }

    /// Builds a `rows`×`cols` char array.
    ///
    /// # Errors
    /// Returns an error when `rows * cols` differs from `data.len()`, including
    /// the case where `rows * cols` overflows `usize`.
    pub fn new(data: Vec<char>, rows: usize, cols: usize) -> Result<Self, String> {
        // `checked_mul` avoids an overflow panic (debug) or a wrapped match (release).
        if rows.checked_mul(cols) != Some(data.len()) {
            return Err(format!(
                "Char data length {} doesn't match dimensions {}x{}",
                data.len(),
                rows,
                cols
            ));
        }
        Ok(CharArray { data, rows, cols })
    }
}
331
332impl StringArray {
333 pub fn new(data: Vec<String>, shape: Vec<usize>) -> Result<Self, String> {
334 let expected: usize = shape.iter().product();
335 if data.len() != expected {
336 return Err(format!(
337 "StringArray data length {} doesn't match shape {:?} ({} elements)",
338 data.len(),
339 shape,
340 expected
341 ));
342 }
343 let (rows, cols) = if shape.len() >= 2 {
344 (shape[0], shape[1])
345 } else if shape.len() == 1 {
346 (1, shape[0])
347 } else {
348 (0, 0)
349 };
350 Ok(StringArray {
351 data,
352 shape,
353 rows,
354 cols,
355 })
356 }
357 pub fn new_2d(data: Vec<String>, rows: usize, cols: usize) -> Result<Self, String> {
358 Self::new(data, vec![rows, cols])
359 }
360 pub fn rows(&self) -> usize {
361 self.shape.first().copied().unwrap_or(1)
362 }
363 pub fn cols(&self) -> usize {
364 self.shape.get(1).copied().unwrap_or(1)
365 }
366}
367
368impl Tensor {
371 pub fn new(data: Vec<f64>, shape: Vec<usize>) -> Result<Self, String> {
372 let expected: usize = shape.iter().product();
373 if data.len() != expected {
374 return Err(format!(
375 "Tensor data length {} doesn't match shape {:?} ({} elements)",
376 data.len(),
377 shape,
378 expected
379 ));
380 }
381 let (rows, cols) = if shape.len() >= 2 {
382 (shape[0], shape[1])
383 } else if shape.len() == 1 {
384 (1, shape[0])
385 } else {
386 (0, 0)
387 };
388 Ok(Tensor {
389 data,
390 shape,
391 rows,
392 cols,
393 dtype: NumericDType::F64,
394 })
395 }
396
397 pub fn new_2d(data: Vec<f64>, rows: usize, cols: usize) -> Result<Self, String> {
398 Self::new(data, vec![rows, cols])
399 }
400
401 pub fn from_f32(data: Vec<f32>, shape: Vec<usize>) -> Result<Self, String> {
402 let converted: Vec<f64> = data.into_iter().map(|v| v as f64).collect();
403 Self::new_with_dtype(converted, shape, NumericDType::F32)
404 }
405
406 pub fn from_f32_slice(data: &[f32], shape: &[usize]) -> Result<Self, String> {
407 let converted: Vec<f64> = data.iter().map(|&v| v as f64).collect();
408 Self::new_with_dtype(converted, shape.to_vec(), NumericDType::F32)
409 }
410
411 pub fn new_with_dtype(
412 data: Vec<f64>,
413 shape: Vec<usize>,
414 dtype: NumericDType,
415 ) -> Result<Self, String> {
416 let mut t = Self::new(data, shape)?;
417 t.dtype = dtype;
418 Ok(t)
419 }
420
421 pub fn zeros(shape: Vec<usize>) -> Self {
422 let size: usize = shape.iter().product();
423 let (rows, cols) = if shape.len() >= 2 {
424 (shape[0], shape[1])
425 } else if shape.len() == 1 {
426 (1, shape[0])
427 } else {
428 (0, 0)
429 };
430 Tensor {
431 data: vec![0.0; size],
432 shape,
433 rows,
434 cols,
435 dtype: NumericDType::F64,
436 }
437 }
438
439 pub fn ones(shape: Vec<usize>) -> Self {
440 let size: usize = shape.iter().product();
441 let (rows, cols) = if shape.len() >= 2 {
442 (shape[0], shape[1])
443 } else if shape.len() == 1 {
444 (1, shape[0])
445 } else {
446 (0, 0)
447 };
448 Tensor {
449 data: vec![1.0; size],
450 shape,
451 rows,
452 cols,
453 dtype: NumericDType::F64,
454 }
455 }
456
457 pub fn zeros2(rows: usize, cols: usize) -> Self {
459 Self::zeros(vec![rows, cols])
460 }
461 pub fn ones2(rows: usize, cols: usize) -> Self {
462 Self::ones(vec![rows, cols])
463 }
464
465 pub fn rows(&self) -> usize {
466 self.shape.first().copied().unwrap_or(1)
467 }
468 pub fn cols(&self) -> usize {
469 self.shape.get(1).copied().unwrap_or(1)
470 }
471
472 pub fn get2(&self, row: usize, col: usize) -> Result<f64, String> {
473 let rows = self.rows();
474 let cols = self.cols();
475 if row >= rows || col >= cols {
476 return Err(format!(
477 "Index ({row}, {col}) out of bounds for {rows}x{cols} tensor"
478 ));
479 }
480 Ok(self.data[row + col * rows])
482 }
483
484 pub fn set2(&mut self, row: usize, col: usize, value: f64) -> Result<(), String> {
485 let rows = self.rows();
486 let cols = self.cols();
487 if row >= rows || col >= cols {
488 return Err(format!(
489 "Index ({row}, {col}) out of bounds for {rows}x{cols} tensor"
490 ));
491 }
492 self.data[row + col * rows] = value;
494 Ok(())
495 }
496
497 pub fn scalar_to_tensor2(scalar: f64, rows: usize, cols: usize) -> Tensor {
498 Tensor {
499 data: vec![scalar; rows * cols],
500 shape: vec![rows, cols],
501 rows,
502 cols,
503 dtype: NumericDType::F64,
504 }
505 }
506 }
508
509impl SparseTensor {
510 pub fn new(
511 rows: usize,
512 cols: usize,
513 col_ptrs: Vec<usize>,
514 row_indices: Vec<usize>,
515 values: Vec<f64>,
516 ) -> Result<Self, String> {
517 if col_ptrs.len() != cols.saturating_add(1) {
518 return Err(format!(
519 "SparseTensor col_ptrs length {} doesn't match cols {}",
520 col_ptrs.len(),
521 cols
522 ));
523 }
524 if row_indices.len() != values.len() {
525 return Err(format!(
526 "SparseTensor row index length {} doesn't match value length {}",
527 row_indices.len(),
528 values.len()
529 ));
530 }
531 if col_ptrs.first().copied().unwrap_or(usize::MAX) != 0 {
532 return Err("SparseTensor col_ptrs must start at 0".to_string());
533 }
534 if col_ptrs.last().copied().unwrap_or(usize::MAX) != values.len() {
535 return Err("SparseTensor final col_ptr must equal nnz".to_string());
536 }
537 for window in col_ptrs.windows(2) {
538 if window[0] > window[1] {
539 return Err("SparseTensor col_ptrs must be nondecreasing".to_string());
540 }
541 }
542 for col in 0..cols {
543 let start = col_ptrs[col];
544 let end = col_ptrs[col + 1];
545 let mut prev: Option<usize> = None;
546 for &row in &row_indices[start..end] {
547 if row >= rows {
548 return Err(format!("SparseTensor row index {row} exceeds rows {rows}"));
549 }
550 if prev.is_some_and(|p| p >= row) {
551 return Err("SparseTensor row indices must be sorted and unique".to_string());
552 }
553 prev = Some(row);
554 }
555 }
556 Ok(Self {
557 rows,
558 cols,
559 col_ptrs,
560 row_indices,
561 values,
562 })
563 }
564
565 pub fn zeros(rows: usize, cols: usize) -> Self {
566 Self {
567 rows,
568 cols,
569 col_ptrs: vec![0; cols.saturating_add(1)],
570 row_indices: Vec::new(),
571 values: Vec::new(),
572 }
573 }
574
575 pub fn nnz(&self) -> usize {
576 self.values.len()
577 }
578
579 pub fn shape(&self) -> Vec<usize> {
580 vec![self.rows, self.cols]
581 }
582
583 pub fn to_dense(&self) -> Result<Tensor, String> {
584 let len = self
585 .rows
586 .checked_mul(self.cols)
587 .ok_or_else(|| "SparseTensor dense dimensions overflow usize".to_string())?;
588 let mut data = Vec::new();
589 data.try_reserve_exact(len)
590 .map_err(|err| format!("SparseTensor dense allocation failed: {err}"))?;
591 data.resize(len, 0.0);
592 for col in 0..self.cols {
593 for idx in self.col_ptrs[col]..self.col_ptrs[col + 1] {
594 let row = self.row_indices[idx];
595 data[row + col * self.rows] = self.values[idx];
596 }
597 }
598 Tensor::new(data, self.shape())
599 }
600
601 pub fn get(&self, row: usize, col: usize) -> Option<f64> {
602 if row >= self.rows || col >= self.cols {
603 return None;
604 }
605 let start = self.col_ptrs[col];
606 let end = self.col_ptrs[col + 1];
607 self.row_indices[start..end]
608 .binary_search(&row)
609 .ok()
610 .map(|offset| self.values[start + offset])
611 }
612}
613
#[cfg(test)]
mod sparse_tensor_tests {
    use super::*;

    /// `to_dense` must report dimension overflow instead of attempting the allocation.
    #[test]
    fn to_dense_rejects_overflowing_dimensions() {
        let oversized = SparseTensor::zeros(usize::MAX, 2);
        let message = oversized
            .to_dense()
            .expect_err("usize::MAX x 2 dense size must overflow");
        assert!(message.contains("overflow"));
    }
}
632
633impl ComplexTensor {
634 pub fn new(data: Vec<(f64, f64)>, shape: Vec<usize>) -> Result<Self, String> {
635 let expected: usize = shape.iter().product();
636 if data.len() != expected {
637 return Err(format!(
638 "ComplexTensor data length {} doesn't match shape {:?} ({} elements)",
639 data.len(),
640 shape,
641 expected
642 ));
643 }
644 let (rows, cols) = if shape.len() >= 2 {
645 (shape[0], shape[1])
646 } else if shape.len() == 1 {
647 (1, shape[0])
648 } else {
649 (0, 0)
650 };
651 Ok(ComplexTensor {
652 data,
653 shape,
654 rows,
655 cols,
656 })
657 }
658 pub fn new_2d(data: Vec<(f64, f64)>, rows: usize, cols: usize) -> Result<Self, String> {
659 Self::new(data, vec![rows, cols])
660 }
661 pub fn zeros(shape: Vec<usize>) -> Self {
662 let size: usize = shape.iter().product();
663 let (rows, cols) = if shape.len() >= 2 {
664 (shape[0], shape[1])
665 } else if shape.len() == 1 {
666 (1, shape[0])
667 } else {
668 (0, 0)
669 };
670 ComplexTensor {
671 data: vec![(0.0, 0.0); size],
672 shape,
673 rows,
674 cols,
675 }
676 }
677}
678
/// Largest element count for which an array with more than two dimensions is
/// printed page by page (used by `should_expand_nd_display`).
const MAX_ND_DISPLAY_ELEMENTS: usize = 4096;
680
681fn should_expand_nd_display(shape: &[usize]) -> bool {
682 shape.len() > 2
683 && matches!(
684 total_len(shape),
685 Some(total) if total > 0 && total <= MAX_ND_DISPLAY_ELEMENTS
686 )
687}
688
/// Column-major strides for `shape`: the first stride is 1, and each later
/// stride is the product of all earlier dimensions (saturating at `usize::MAX`).
fn column_major_strides(shape: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .scan(1usize, |running, &dim| {
            let stride = *running;
            *running = running.saturating_mul(dim);
            Some(stride)
        })
        .collect()
}
698
/// Decodes a linear page number into zero-based coordinates over `page_shape`.
///
/// The first dimension varies fastest. A zero-length dimension contributes
/// coordinate 0 and does not consume any of the index.
fn decode_page_coords(page_index: usize, page_shape: &[usize]) -> Vec<usize> {
    let mut remaining = page_index;
    page_shape
        .iter()
        .map(|&dim| match dim {
            0 => 0,
            _ => {
                let coord = remaining % dim;
                remaining /= dim;
                coord
            }
        })
        .collect()
}
711
712fn write_nd_pages(
713 f: &mut fmt::Formatter<'_>,
714 shape: &[usize],
715 mut write_element: impl FnMut(&mut fmt::Formatter<'_>, usize) -> fmt::Result,
716) -> fmt::Result {
717 if shape.len() <= 2 {
718 return Ok(());
719 }
720 let rows = shape[0];
721 let cols = shape[1];
722 if rows == 0 || cols == 0 {
723 return write!(f, "[]");
724 }
725 let Some(page_count) = total_len(&shape[2..]) else {
726 return write!(f, "Tensor(shape={shape:?})");
727 };
728 if page_count == 0 {
729 return write!(f, "[]");
730 }
731 let strides = column_major_strides(shape);
732 for page_index in 0..page_count {
733 if page_index > 0 {
734 write!(f, "\n\n")?;
735 }
736 let coords = decode_page_coords(page_index, &shape[2..]);
737 write!(f, "(:, :")?;
738 for &coord in &coords {
739 write!(f, ", {}", coord + 1)?;
740 }
741 write!(f, ") =")?;
742
743 let mut page_base = 0usize;
744 for (offset, &coord) in coords.iter().enumerate() {
745 page_base += coord * strides[offset + 2];
746 }
747 for r in 0..rows {
748 writeln!(f)?;
749 write!(f, " ")?;
750 for c in 0..cols {
751 if c > 0 {
752 write!(f, " ")?;
753 }
754 let linear = page_base + r + c * rows;
755 write_element(f, linear)?;
756 }
757 }
758 }
759 Ok(())
760}
761
762impl fmt::Display for Tensor {
763 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
764 match self.shape.len() {
765 0 | 1 => {
766 write!(f, "[")?;
768 for (i, v) in self.data.iter().enumerate() {
769 if i > 0 {
770 write!(f, " ")?;
771 }
772 write!(f, "{}", format_number(*v))?;
773 }
774 write!(f, "]")
775 }
776 2 => {
777 let rows = self.rows();
778 let cols = self.cols();
779 for r in 0..rows {
781 writeln!(f)?;
782 write!(f, " ")?; for c in 0..cols {
784 if c > 0 {
785 write!(f, " ")?;
786 }
787 let v = self.data[r + c * rows];
788 write!(f, "{}", format_number(v))?;
789 }
790 }
791 Ok(())
792 }
793 _ => {
794 if should_expand_nd_display(&self.shape) {
795 write_nd_pages(f, &self.shape, |f, idx| {
796 write!(f, "{}", format_number(self.data[idx]))
797 })
798 } else {
799 write!(f, "Tensor(shape={:?})", self.shape)
800 }
801 }
802 }
803 }
804}
805
806impl fmt::Display for SparseTensor {
807 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
808 writeln!(
809 f,
810 "{}x{} sparse double matrix with {} nonzero entries",
811 self.rows,
812 self.cols,
813 self.nnz()
814 )?;
815 if self.nnz() == 0 {
816 return Ok(());
817 }
818 for col in 0..self.cols {
819 for idx in self.col_ptrs[col]..self.col_ptrs[col + 1] {
820 let row = self.row_indices[idx];
821 writeln!(
822 f,
823 " ({},{}) {}",
824 row + 1,
825 col + 1,
826 format_number(self.values[idx])
827 )?;
828 }
829 }
830 Ok(())
831 }
832}
833
834impl fmt::Display for StringArray {
835 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
836 let (rows, cols) = match self.shape.len() {
837 0 => (0, 0),
838 1 => (1, self.shape[0]),
839 _ => (self.shape[0], self.shape[1]),
840 };
841 let count = self.data.len();
842 if count == 1 && rows == 1 && cols == 1 {
843 let v = &self.data[0];
844 if v == "<missing>" {
845 return write!(f, "<missing>");
846 }
847 let escaped = v.replace('"', "\\\"");
848 return write!(f, "\"{escaped}\"");
849 }
850 if self.shape.len() > 2 {
851 let dims: Vec<String> = self.shape.iter().map(|d| d.to_string()).collect();
852 return write!(f, "{} string array", dims.join("x"));
853 }
854 write!(f, "{rows}x{cols} string array")?;
855 if rows == 0 || cols == 0 {
856 return Ok(());
857 }
858 for r in 0..rows {
859 writeln!(f)?;
860 write!(f, " ")?;
861 for c in 0..cols {
862 if c > 0 {
863 write!(f, " ")?;
864 }
865 let v = &self.data[r + c * rows];
866 if v == "<missing>" {
867 write!(f, "<missing>")?;
868 } else {
869 let escaped = v.replace('"', "\\\"");
870 write!(f, "\"{escaped}\"")?;
871 }
872 }
873 }
874 Ok(())
875 }
876}
877
878impl fmt::Display for LogicalArray {
879 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
880 if self.data.len() == 1 {
881 return write!(f, "{}", if self.data[0] != 0 { 1 } else { 0 });
882 }
883 match self.shape.len() {
884 0 => write!(f, "[]"),
885 1 => {
886 write!(f, "[")?;
887 for (i, v) in self.data.iter().enumerate() {
888 if i > 0 {
889 write!(f, " ")?;
890 }
891 write!(f, "{}", if *v != 0 { 1 } else { 0 })?;
892 }
893 write!(f, "]")
894 }
895 2 => {
896 let rows = self.shape[0];
897 let cols = self.shape[1];
898 for r in 0..rows {
900 writeln!(f)?;
901 write!(f, " ")?; for c in 0..cols {
903 if c > 0 {
904 write!(f, " ")?;
905 }
906 let idx = r + c * rows;
907 write!(f, "{}", if self.data[idx] != 0 { 1 } else { 0 })?;
908 }
909 }
910 Ok(())
911 }
912 _ => {
913 if should_expand_nd_display(&self.shape) {
914 write_nd_pages(f, &self.shape, |f, idx| {
915 write!(f, "{}", if self.data[idx] != 0 { 1 } else { 0 })
916 })
917 } else {
918 let dims: Vec<String> = self.shape.iter().map(|d| d.to_string()).collect();
919 write!(f, "{} logical array", dims.join("x"))
920 }
921 }
922 }
923 }
924}
925
926impl fmt::Display for CharArray {
927 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
928 for r in 0..self.rows {
929 writeln!(f)?;
930 write!(f, " ")?; for c in 0..self.cols {
932 let ch = self.data[r * self.cols + c];
933 write!(f, "{ch}")?;
934 }
935 }
936 Ok(())
937 }
938}
939
940impl From<i32> for Value {
942 fn from(i: i32) -> Self {
943 Value::Int(IntValue::I32(i))
944 }
945}
946impl From<i64> for Value {
947 fn from(i: i64) -> Self {
948 Value::Int(IntValue::I64(i))
949 }
950}
951impl From<u32> for Value {
952 fn from(i: u32) -> Self {
953 Value::Int(IntValue::U32(i))
954 }
955}
956impl From<u64> for Value {
957 fn from(i: u64) -> Self {
958 Value::Int(IntValue::U64(i))
959 }
960}
961impl From<i16> for Value {
962 fn from(i: i16) -> Self {
963 Value::Int(IntValue::I16(i))
964 }
965}
966impl From<i8> for Value {
967 fn from(i: i8) -> Self {
968 Value::Int(IntValue::I8(i))
969 }
970}
971impl From<u16> for Value {
972 fn from(i: u16) -> Self {
973 Value::Int(IntValue::U16(i))
974 }
975}
976impl From<u8> for Value {
977 fn from(i: u8) -> Self {
978 Value::Int(IntValue::U8(i))
979 }
980}
981
982impl From<f64> for Value {
983 fn from(f: f64) -> Self {
984 Value::Num(f)
985 }
986}
987
988impl From<bool> for Value {
989 fn from(b: bool) -> Self {
990 Value::Bool(b)
991 }
992}
993
994impl From<String> for Value {
995 fn from(s: String) -> Self {
996 Value::String(s)
997 }
998}
999
1000impl From<&str> for Value {
1001 fn from(s: &str) -> Self {
1002 Value::String(s.to_string())
1003 }
1004}
1005
1006impl From<Tensor> for Value {
1007 fn from(m: Tensor) -> Self {
1008 Value::Tensor(m)
1009 }
1010}
1011
1012impl TryFrom<&Value> for i32 {
1016 type Error = String;
1017 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1018 match v {
1019 Value::Int(i) => Ok(i.to_i64() as i32),
1020 Value::Num(n) => Ok(*n as i32),
1021 _ => Err(format!("cannot convert {v:?} to i32")),
1022 }
1023 }
1024}
1025
1026impl TryFrom<&Value> for f64 {
1027 type Error = String;
1028 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1029 match v {
1030 Value::Num(n) => Ok(*n),
1031 Value::Int(i) => Ok(i.to_f64()),
1032 _ => Err(format!("cannot convert {v:?} to f64")),
1033 }
1034 }
1035}
1036
1037impl TryFrom<&Value> for bool {
1038 type Error = String;
1039 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1040 match v {
1041 Value::Bool(b) => Ok(*b),
1042 Value::Int(i) => Ok(!i.is_zero()),
1043 Value::Num(n) => Ok(*n != 0.0),
1044 _ => Err(format!("cannot convert {v:?} to bool")),
1045 }
1046 }
1047}
1048
1049impl TryFrom<&Value> for String {
1050 type Error = String;
1051 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1052 match v {
1053 Value::String(s) => Ok(s.clone()),
1054 Value::StringArray(sa) => {
1055 if sa.data.len() == 1 {
1056 Ok(sa.data[0].clone())
1057 } else {
1058 Err("cannot convert string array to scalar string".to_string())
1059 }
1060 }
1061 Value::CharArray(ca) => {
1062 if ca.rows == 1 {
1064 Ok(ca.data.iter().collect())
1065 } else {
1066 Err("cannot convert multi-row char array to scalar string".to_string())
1067 }
1068 }
1069 Value::Int(i) => Ok(i.to_i64().to_string()),
1070 Value::Num(n) => Ok(n.to_string()),
1071 Value::Bool(b) => Ok(b.to_string()),
1072 _ => Err(format!("cannot convert {v:?} to String")),
1073 }
1074 }
1075}
1076
1077impl TryFrom<&Value> for Tensor {
1078 type Error = String;
1079 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1080 match v {
1081 Value::Tensor(m) => Ok(m.clone()),
1082 _ => Err(format!("cannot convert {v:?} to Tensor")),
1083 }
1084 }
1085}
1086
1087impl TryFrom<&Value> for Value {
1088 type Error = String;
1089 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1090 Ok(v.clone())
1091 }
1092}
1093
1094impl TryFrom<&Value> for Vec<Value> {
1095 type Error = String;
1096 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1097 match v {
1098 Value::Cell(c) => Ok(c.data.iter().map(|p| (**p).clone()).collect()),
1099 _ => Err(format!("cannot convert {v:?} to Vec<Value>")),
1100 }
1101 }
1102}
1103
1104use serde::{Deserialize, Serialize};
1105
/// Static type descriptor, produced for runtime values by `Type::from_value`.
///
/// Shapes are `Option<Vec<Option<usize>>>`. An outer `None` means the rank is
/// unknown; an inner `None` means that particular dimension is unknown.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum Type {
    /// Scalar integer.
    Int,
    /// Scalar number. `Type::from_value` also maps complex scalars here.
    Num,
    /// Scalar boolean.
    Bool,
    /// Logical array.
    Logical {
        /// Optional per-dimension sizes.
        shape: Option<Vec<Option<usize>>>,
    },
    /// Scalar string.
    String,
    /// Numeric array (real, complex, sparse and GPU values all map here).
    Tensor {
        /// Optional per-dimension sizes.
        shape: Option<Vec<Option<usize>>>,
    },
    /// Cell array.
    Cell {
        /// Element type, if known. `from_value` derives it from the first element only.
        element_type: Option<Box<Type>>,
        /// Number of elements, if known.
        length: Option<usize>,
    },
    /// Callable value.
    Function {
        /// Parameter types.
        params: Vec<Type>,
        /// Return type.
        returns: Box<Type>,
    },
    /// No value.
    Void,
    /// Type not known; treated as compatible with every type.
    Unknown,
    /// One of several types.
    Union(Vec<Type>),
    /// Struct value.
    Struct {
        /// Known field names, if any.
        known_fields: Option<Vec<String>>,
    },
    /// Types of a value list (see `Value::OutputList`).
    OutputList(Vec<Type>),
}
1156
1157impl Type {
1158 pub fn tensor() -> Self {
1160 Type::Tensor { shape: None }
1161 }
1162
1163 pub fn logical() -> Self {
1165 Type::Logical { shape: None }
1166 }
1167
1168 pub fn logical_with_shape(shape: Vec<usize>) -> Self {
1170 Type::Logical {
1171 shape: Some(shape.into_iter().map(Some).collect()),
1172 }
1173 }
1174
1175 pub fn tensor_with_shape(shape: Vec<usize>) -> Self {
1177 Type::Tensor {
1178 shape: Some(shape.into_iter().map(Some).collect()),
1179 }
1180 }
1181
1182 pub fn cell() -> Self {
1184 Type::Cell {
1185 element_type: None,
1186 length: None,
1187 }
1188 }
1189
1190 pub fn cell_of(element_type: Type) -> Self {
1192 Type::Cell {
1193 element_type: Some(Box::new(element_type)),
1194 length: None,
1195 }
1196 }
1197
1198 pub fn is_compatible_with(&self, other: &Type) -> bool {
1200 match (self, other) {
1201 (Type::Unknown, _) | (_, Type::Unknown) => true,
1202 (Type::Int, Type::Num) | (Type::Num, Type::Int) => true, (Type::Tensor { .. }, Type::Tensor { .. }) => true, (Type::OutputList(a), Type::OutputList(b)) => a.len() == b.len(),
1205 (a, b) => a == b,
1206 }
1207 }
1208
1209 pub fn unify(&self, other: &Type) -> Type {
1211 match (self, other) {
1212 (Type::Unknown, t) | (t, Type::Unknown) => t.clone(),
1213 (Type::Int, Type::Num) | (Type::Num, Type::Int) => Type::Num,
1214 (Type::Tensor { shape: a }, Type::Tensor { shape: b }) => {
1215 let a_norm = match a {
1216 Some(dims) if dims.is_empty() => None,
1217 _ => a.clone(),
1218 };
1219 let b_norm = match b {
1220 Some(dims) if dims.is_empty() => None,
1221 _ => b.clone(),
1222 };
1223 let a_unknown = a_norm
1224 .as_ref()
1225 .map(|dims| dims.iter().all(|d| d.is_none()))
1226 .unwrap_or(true);
1227 let b_unknown = b_norm
1228 .as_ref()
1229 .map(|dims| dims.iter().all(|d| d.is_none()))
1230 .unwrap_or(true);
1231 if a_norm == b_norm
1232 || (!a_unknown && b_unknown)
1233 || (a_norm.is_some() && b_norm.is_none())
1234 {
1235 Type::Tensor { shape: a_norm }
1236 } else if (a_unknown && !b_unknown) || (a_norm.is_none() && b_norm.is_some()) {
1237 Type::Tensor { shape: b_norm }
1238 } else {
1239 Type::tensor()
1240 }
1241 }
1242 (Type::Logical { shape: a }, Type::Logical { shape: b }) => {
1243 let a_norm = match a {
1244 Some(dims) if dims.is_empty() => None,
1245 _ => a.clone(),
1246 };
1247 let b_norm = match b {
1248 Some(dims) if dims.is_empty() => None,
1249 _ => b.clone(),
1250 };
1251 let a_unknown = a_norm
1252 .as_ref()
1253 .map(|dims| dims.iter().all(|d| d.is_none()))
1254 .unwrap_or(true);
1255 let b_unknown = b_norm
1256 .as_ref()
1257 .map(|dims| dims.iter().all(|d| d.is_none()))
1258 .unwrap_or(true);
1259 if a_norm == b_norm
1260 || (!a_unknown && b_unknown)
1261 || (a_norm.is_some() && b_norm.is_none())
1262 {
1263 Type::Logical { shape: a_norm }
1264 } else if (a_unknown && !b_unknown) || (a_norm.is_none() && b_norm.is_some()) {
1265 Type::Logical { shape: b_norm }
1266 } else {
1267 Type::logical()
1268 }
1269 }
1270 (Type::Struct { known_fields: a }, Type::Struct { known_fields: b }) => match (a, b) {
1271 (None, None) => Type::Struct { known_fields: None },
1272 (Some(ka), None) | (None, Some(ka)) => Type::Struct {
1273 known_fields: Some(ka.clone()),
1274 },
1275 (Some(ka), Some(kb)) => {
1276 let mut set: std::collections::BTreeSet<String> = ka.iter().cloned().collect();
1277 set.extend(kb.iter().cloned());
1278 Type::Struct {
1279 known_fields: Some(set.into_iter().collect()),
1280 }
1281 }
1282 },
1283 (Type::OutputList(a), Type::OutputList(b)) => {
1284 if a.len() == b.len() {
1285 let items = a
1286 .iter()
1287 .zip(b.iter())
1288 .map(|(lhs, rhs)| lhs.unify(rhs))
1289 .collect();
1290 Type::OutputList(items)
1291 } else {
1292 Type::OutputList(vec![Type::Unknown; a.len().max(b.len())])
1293 }
1294 }
1295 (a, b) if a == b => a.clone(),
1296 _ => Type::Union(vec![self.clone(), other.clone()]),
1297 }
1298 }
1299
1300 pub fn from_value(value: &Value) -> Type {
1302 match value {
1303 Value::Int(_) => Type::Int,
1304 Value::Num(_) => Type::Num,
1305 Value::Complex(_, _) => Type::Num, Value::Bool(_) => Type::Bool,
1307 Value::LogicalArray(arr) => Type::Logical {
1308 shape: Some(arr.shape.iter().map(|&d| Some(d)).collect()),
1309 },
1310 Value::String(_) => Type::String,
1311 Value::StringArray(_sa) => {
1312 Type::cell_of(Type::String)
1314 }
1315 Value::Tensor(t) => Type::Tensor {
1316 shape: Some(t.shape.iter().map(|&d| Some(d)).collect()),
1317 },
1318 Value::SparseTensor(t) => Type::Tensor {
1319 shape: Some(vec![Some(t.rows), Some(t.cols)]),
1320 },
1321 Value::ComplexTensor(t) => Type::Tensor {
1322 shape: Some(t.shape.iter().map(|&d| Some(d)).collect()),
1323 },
1324 Value::Cell(cells) => {
1325 if cells.data.is_empty() {
1326 Type::cell()
1327 } else {
1328 let element_type = Type::from_value(&cells.data[0]);
1330 Type::Cell {
1331 element_type: Some(Box::new(element_type)),
1332 length: Some(cells.data.len()),
1333 }
1334 }
1335 }
1336 Value::GpuTensor(h) => Type::Tensor {
1337 shape: Some(h.shape.iter().map(|&d| Some(d)).collect()),
1338 },
1339 Value::Object(_) => Type::Unknown,
1340 Value::HandleObject(_) => Type::Unknown,
1341 Value::Listener(_) => Type::Unknown,
1342 Value::Struct(_) => Type::Struct { known_fields: None },
1343 Value::FunctionHandle(_)
1344 | Value::ExternalFunctionHandle(_)
1345 | Value::MethodFunctionHandle(_)
1346 | Value::BoundFunctionHandle { .. } => Type::Function {
1347 params: vec![Type::Unknown],
1348 returns: Box::new(Type::Unknown),
1349 },
1350 Value::Closure(_) => Type::Function {
1351 params: vec![Type::Unknown],
1352 returns: Box::new(Type::Unknown),
1353 },
1354 Value::ClassRef(_) => Type::Unknown,
1355 Value::MException(_) => Type::Unknown,
1356 Value::CharArray(ca) => {
1357 Type::Cell {
1359 element_type: Some(Box::new(Type::String)),
1360 length: Some(ca.rows * ca.cols),
1361 }
1362 }
1363 Value::OutputList(values) => {
1364 Type::OutputList(values.iter().map(Type::from_value).collect())
1365 }
1366 }
1367 }
1368}
1369
/// Function value that carries captured values (see `Value::Closure`).
#[derive(Debug, Clone, PartialEq)]
pub struct Closure {
    /// Name of the underlying function.
    pub function_name: String,
    /// NOTE(review): presumably an index/id of the resolved function, like
    /// `Value::BoundFunctionHandle::function` — confirm.
    pub bound_function: Option<usize>,
    /// Captured values. NOTE(review): presumably stored at closure creation — confirm.
    pub captures: Vec<Value>,
}
1376
/// Coarse operation category a builtin advertises through
/// `BuiltinFunction::accel_tags`; presumably consumed by the acceleration /
/// fusion planner — TODO confirm consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelTag {
    Unary,
    Elementwise,
    Reduction,
    MatMul,
    Transpose,
    ArrayConstruct,
}
1387
/// Error type a builtin future resolves to on failure (alias of
/// `runmat_async::RuntimeError`).
pub type BuiltinControlFlow = runmat_async::RuntimeError;

/// Pinned, boxed future returned by every builtin implementation. It is not
/// bounded by `Send`, so it must be polled on the thread that created it.
pub type BuiltinFuture = Pin<Box<dyn Future<Output = Result<Value, BuiltinControlFlow>> + 'static>>;
1393
/// Compile-time information handed to context-aware type resolvers
/// (`TypeResolverKind::WithContext`).
#[derive(Clone, Debug, Default)]
pub struct ResolveContext {
    /// Statically known argument values, one entry per call argument in call
    /// order (presumably `LiteralValue::Unknown` for non-literal arguments —
    /// verify against the producer).
    pub literal_args: Vec<LiteralValue>,
}

/// A call-argument value that is known at compile time.
#[derive(Clone, Debug, PartialEq)]
pub enum LiteralValue {
    Number(f64),
    Bool(bool),
    String(String),
    Vector(Vec<LiteralValue>),
    Unknown,
}

impl ResolveContext {
    /// Creates a context from the literal arguments of a call.
    pub fn new(literal_args: Vec<LiteralValue>) -> Self {
        Self { literal_args }
    }

    /// Dimension arguments starting at the first argument; see
    /// [`ResolveContext::numeric_dims_from`].
    pub fn numeric_dims(&self) -> Vec<Option<usize>> {
        self.numeric_dims_from(0)
    }

    /// Interprets the arguments from index `start` onward as dimensions.
    ///
    /// If the first of those arguments is a vector literal (e.g. `zeros([2 3])`),
    /// only its elements are used and any trailing arguments are ignored.
    /// Otherwise every remaining argument contributes one entry. Entries that
    /// are not non-negative integer literals representable as `usize` become
    /// `None`. A `start` past the end yields an empty vector.
    pub fn numeric_dims_from(&self, start: usize) -> Vec<Option<usize>> {
        let rest = self.literal_args.get(start..).unwrap_or(&[]);
        let dims = match rest.first() {
            Some(LiteralValue::Vector(values)) => values.as_slice(),
            _ => rest,
        };
        dims.iter().map(Self::numeric_dimension_from_literal).collect()
    }

    /// The string literal at `index`, ASCII-lowercased for case-insensitive
    /// option matching; `None` if absent or not a string.
    pub fn literal_string_at(&self, index: usize) -> Option<String> {
        match self.literal_args.get(index) {
            Some(LiteralValue::String(value)) => Some(value.to_ascii_lowercase()),
            _ => None,
        }
    }

    /// The boolean literal at `index`, if there is one.
    pub fn literal_bool_at(&self, index: usize) -> Option<bool> {
        match self.literal_args.get(index) {
            Some(LiteralValue::Bool(value)) => Some(*value),
            _ => None,
        }
    }

    /// A clone of the vector literal at `index`, if there is one.
    pub fn literal_vector_at(&self, index: usize) -> Option<Vec<LiteralValue>> {
        match self.literal_args.get(index) {
            Some(LiteralValue::Vector(values)) => Some(values.clone()),
            _ => None,
        }
    }

    /// The vector literal at `index` read as dimensions.
    ///
    /// Returns `None` when the argument is missing, is not a vector, or
    /// contains a nested vector.
    pub fn numeric_vector_at(&self, index: usize) -> Option<Vec<Option<usize>>> {
        let Some(LiteralValue::Vector(values)) = self.literal_args.get(index) else {
            return None;
        };
        if values
            .iter()
            .any(|value| matches!(value, LiteralValue::Vector(_)))
        {
            return None;
        }
        Some(
            values
                .iter()
                .map(Self::numeric_dimension_from_literal)
                .collect(),
        )
    }

    /// Converts a number literal to a dimension.
    ///
    /// The value must be finite and within `1e-9` of a non-negative integer
    /// that fits in `usize`. Anything else yields `None`.
    fn numeric_dimension_from_literal(value: &LiteralValue) -> Option<usize> {
        match *value {
            LiteralValue::Number(num) if num.is_finite() => {
                let rounded = num.round();
                let is_integral = (num - rounded).abs() <= 1e-9;
                // `as usize` saturates, so reject out-of-range values up front
                // instead of reporting a bogus `usize::MAX` dimension.
                // `usize::MAX as f64` rounds up to a value that itself does not
                // fit, hence the strict comparison.
                let fits = rounded >= 0.0 && rounded < usize::MAX as f64;
                (is_integral && fits).then_some(rounded as usize)
            }
            _ => None,
        }
    }
}
1486
// Unit tests for the literal-argument accessors on `ResolveContext`.
#[cfg(test)]
mod resolve_context_tests {
    use super::{LiteralValue, ResolveContext};

    // A leading vector literal supplies the dimensions.
    #[test]
    fn numeric_dims_reads_vector_literal() {
        let ctx = ResolveContext::new(vec![LiteralValue::Vector(vec![
            LiteralValue::Number(2.0),
            LiteralValue::Number(3.0),
        ])]);
        assert_eq!(ctx.numeric_dims(), vec![Some(2), Some(3)]);
    }

    // Non-numeric arguments keep their position but map to `None`.
    #[test]
    fn numeric_dims_skips_non_numeric_entries() {
        let ctx = ResolveContext::new(vec![
            LiteralValue::Number(4.0),
            LiteralValue::String("like".to_string()),
            LiteralValue::Unknown,
        ]);
        assert_eq!(ctx.numeric_dims(), vec![Some(4), None, None]);
    }

    // Arguments after a leading vector (e.g. a "like" option) are ignored.
    #[test]
    fn numeric_dims_prefers_vector_even_with_trailing_args() {
        let ctx = ResolveContext::new(vec![
            LiteralValue::Vector(vec![LiteralValue::Number(1.0), LiteralValue::Number(5.0)]),
            LiteralValue::String("like".to_string()),
        ]);
        assert_eq!(ctx.numeric_dims(), vec![Some(1), Some(5)]);
    }

    #[test]
    fn literal_string_is_lowercased() {
        let ctx = ResolveContext::new(vec![LiteralValue::String("OmItNaN".to_string())]);
        assert_eq!(ctx.literal_string_at(0), Some("omitnan".to_string()));
    }

    #[test]
    fn literal_bool_is_available() {
        let ctx = ResolveContext::new(vec![LiteralValue::Bool(true)]);
        assert_eq!(ctx.literal_bool_at(0), Some(true));
    }

    #[test]
    fn literal_vector_at_returns_clone() {
        let ctx = ResolveContext::new(vec![LiteralValue::Vector(vec![
            LiteralValue::Number(7.0),
            LiteralValue::Unknown,
        ])]);
        assert_eq!(
            ctx.literal_vector_at(0),
            Some(vec![LiteralValue::Number(7.0), LiteralValue::Unknown])
        );
    }

    #[test]
    fn numeric_vector_at_rejects_nested_vectors() {
        let ctx = ResolveContext::new(vec![LiteralValue::Vector(vec![LiteralValue::Vector(
            vec![LiteralValue::Number(1.0)],
        )])]);
        assert_eq!(ctx.numeric_vector_at(0), None);
    }
}
1551
/// Computes a builtin's return type from its argument types alone.
pub type TypeResolver = fn(args: &[Type]) -> Type;
/// Computes a builtin's return type from argument types plus literal
/// argument values (see `ResolveContext`).
pub type TypeResolverWithContext = fn(args: &[Type], ctx: &ResolveContext) -> Type;

/// Either kind of return-type resolver a builtin may register.
#[derive(Clone, Copy, Debug)]
pub enum TypeResolverKind {
    Simple(TypeResolver),
    WithContext(TypeResolverWithContext),
}

/// Wraps a context-free resolver as a `TypeResolverKind`.
pub fn type_resolver_kind(resolver: TypeResolver) -> TypeResolverKind {
    TypeResolverKind::Simple(resolver)
}

/// Wraps a context-aware resolver as a `TypeResolverKind`.
pub fn type_resolver_kind_ctx(resolver: TypeResolverWithContext) -> TypeResolverKind {
    TypeResolverKind::WithContext(resolver)
}
1568
/// Output-count policy advertised in a `BuiltinDescriptor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinOutputMode {
    Fixed,
    ByRequestedOutputCount,
}

/// Whether/how a builtin is offered in completion lists (interpretation by
/// name; confirm with the completion provider).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinCompletionPolicy {
    Public,
    MethodOnly,
    HiddenInternal,
}

/// How many times a parameter may appear in a signature.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinParamArity {
    Required,
    Optional,
    Variadic,
}

/// Coarse parameter kind used in signature metadata.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinParamType {
    Any,
    NumericScalar,
    IntegerScalar,
    StringScalar,
    NumericArray,
    LogicalArray,
    SizeArg,
    LikePrototype,
    AxesHandle,
    StyleSpec,
    PropertyName,
    PropertyValue,
}

/// One input or output parameter of a builtin signature.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinParamDescriptor {
    pub name: &'static str,
    pub ty: BuiltinParamType,
    pub arity: BuiltinParamArity,
    /// Textual default value, when the parameter has one.
    pub default: Option<&'static str>,
    pub description: &'static str,
}

/// One documented call form of a builtin.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinSignatureDescriptor {
    pub label: &'static str,
    pub inputs: &'static [BuiltinParamDescriptor],
    pub outputs: &'static [BuiltinParamDescriptor],
}

/// A documented error condition of a builtin.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinErrorDescriptor {
    pub code: &'static str,
    /// Error identifier, when one is defined.
    pub identifier: Option<&'static str>,
    pub when: &'static str,
    pub message: &'static str,
}

/// Static, serializable metadata attached to a `BuiltinFunction`.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinDescriptor {
    pub signatures: &'static [BuiltinSignatureDescriptor],
    pub output_mode: BuiltinOutputMode,
    pub completion_policy: BuiltinCompletionPolicy,
    pub errors: &'static [BuiltinErrorDescriptor],
}
1636
/// A registered builtin: metadata, typing information and its async
/// implementation.
#[derive(Debug, Clone)]
pub struct BuiltinFunction {
    /// Name used for (case-insensitive) lookup.
    pub name: &'static str,
    pub description: &'static str,
    pub category: &'static str,
    pub doc: &'static str,
    pub examples: &'static str,
    pub param_types: Vec<Type>,
    /// Return type used when no `type_resolver` is set.
    pub return_type: Type,
    /// Optional resolver that computes the return type from the call site.
    pub type_resolver: Option<TypeResolverKind>,
    /// Entry point; receives the evaluated arguments.
    pub implementation: fn(&[Value]) -> BuiltinFuture,
    pub accel_tags: &'static [AccelTag],
    pub is_sink: bool,
    /// When set, `suppresses_auto_output` reports `true` for this builtin.
    pub suppress_auto_output: bool,
    pub descriptor: Option<&'static BuiltinDescriptor>,
}
1654
1655impl BuiltinFunction {
1656 #[allow(clippy::too_many_arguments)]
1657 pub fn new(
1658 name: &'static str,
1659 description: &'static str,
1660 category: &'static str,
1661 doc: &'static str,
1662 examples: &'static str,
1663 param_types: Vec<Type>,
1664 return_type: Type,
1665 type_resolver: Option<TypeResolverKind>,
1666 implementation: fn(&[Value]) -> BuiltinFuture,
1667 accel_tags: &'static [AccelTag],
1668 is_sink: bool,
1669 suppress_auto_output: bool,
1670 ) -> Self {
1671 Self {
1672 name,
1673 description,
1674 category,
1675 doc,
1676 examples,
1677 param_types,
1678 return_type,
1679 type_resolver,
1680 implementation,
1681 accel_tags,
1682 is_sink,
1683 suppress_auto_output,
1684 descriptor: None,
1685 }
1686 }
1687
1688 pub fn with_descriptor(mut self, descriptor: &'static BuiltinDescriptor) -> Self {
1689 self.descriptor = Some(descriptor);
1690 self
1691 }
1692
1693 pub fn with_descriptor_option(
1694 mut self,
1695 descriptor: Option<&'static BuiltinDescriptor>,
1696 ) -> Self {
1697 self.descriptor = descriptor;
1698 self
1699 }
1700
1701 pub fn infer_return_type(&self, args: &[Type]) -> Type {
1702 self.infer_return_type_with_context(args, &ResolveContext::default())
1703 }
1704
1705 pub fn infer_return_type_with_context(&self, args: &[Type], ctx: &ResolveContext) -> Type {
1706 if let Some(resolver) = self.type_resolver {
1707 return match resolver {
1708 TypeResolverKind::Simple(resolver) => resolver(args),
1709 TypeResolverKind::WithContext(resolver) => resolver(args, ctx),
1710 };
1711 }
1712 self.return_type.clone()
1713 }
1714
1715 pub fn semantics(&self) -> BuiltinSemantics {
1716 semantics::builtin_semantics_for(self)
1717 }
1718}
1719
/// A named runtime constant registered alongside the builtins.
#[derive(Clone)]
pub struct Constant {
    pub name: &'static str,
    pub value: Value,
}

pub mod semantics;
pub mod shape_rules;

// Re-export the semantics API at the crate root.
pub use semantics::{
    builtin_semantics_for, builtin_semantics_for_name, BuiltinAsyncBehavior, BuiltinCompatibility,
    BuiltinEffects, BuiltinEnvironmentEffect, BuiltinPurity, BuiltinSemanticKind, BuiltinSemantics,
    BuiltinWorkspaceEffect, ConcatKind, ShapeTransformKind,
};
1735
1736impl std::fmt::Debug for Constant {
1737 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1738 write!(
1739 f,
1740 "Constant {{ name: {:?}, value: {:?} }}",
1741 self.name, self.value
1742 )
1743 }
1744}
1745
// Collection points for `inventory` registrations on native targets; wasm32
// uses the explicit lists in `wasm_registry` instead.
#[cfg(not(target_arch = "wasm32"))]
inventory::collect!(BuiltinFunction);
#[cfg(not(target_arch = "wasm32"))]
inventory::collect!(Constant);

/// Returns every builtin registered through `inventory`.
#[cfg(not(target_arch = "wasm32"))]
pub fn builtin_functions() -> Vec<&'static BuiltinFunction> {
    inventory::iter::<BuiltinFunction>().collect()
}

/// Returns every builtin submitted to the wasm registry, in submission order.
#[cfg(target_arch = "wasm32")]
pub fn builtin_functions() -> Vec<&'static BuiltinFunction> {
    wasm_registry::builtin_functions()
}
1760
1761#[cfg(not(target_arch = "wasm32"))]
1762static BUILTIN_LOOKUP: OnceLock<HashMap<String, &'static BuiltinFunction>> = OnceLock::new();
1763
1764#[cfg(not(target_arch = "wasm32"))]
1765fn builtin_lookup_map() -> &'static HashMap<String, &'static BuiltinFunction> {
1766 BUILTIN_LOOKUP.get_or_init(|| {
1767 let mut map = HashMap::new();
1768 for func in builtin_functions() {
1769 map.insert(func.name.to_ascii_lowercase(), func);
1770 }
1771 map
1772 })
1773}
1774
1775#[cfg(not(target_arch = "wasm32"))]
1776pub fn builtin_function_by_name(name: &str) -> Option<&'static BuiltinFunction> {
1777 builtin_lookup_map()
1778 .get(&name.to_ascii_lowercase())
1779 .copied()
1780}
1781
1782#[cfg(target_arch = "wasm32")]
1783pub fn builtin_function_by_name(name: &str) -> Option<&'static BuiltinFunction> {
1784 wasm_registry::builtin_functions()
1785 .into_iter()
1786 .find(|f| f.name.eq_ignore_ascii_case(name))
1787}
1788
1789pub fn suppresses_auto_output(name: &str) -> bool {
1790 builtin_function_by_name(name)
1791 .map(|f| f.suppress_auto_output)
1792 .unwrap_or(false)
1793}
1794
/// Returns every constant registered through `inventory`.
#[cfg(not(target_arch = "wasm32"))]
pub fn constants() -> Vec<&'static Constant> {
    inventory::iter::<Constant>().collect()
}

/// Returns every constant submitted to the wasm registry, in submission order.
#[cfg(target_arch = "wasm32")]
pub fn constants() -> Vec<&'static Constant> {
    wasm_registry::constants()
}
1804
/// Free-form documentation record for a builtin; every field except `name`
/// is optional text.
#[derive(Debug)]
pub struct BuiltinDoc {
    pub name: &'static str,
    pub category: Option<&'static str>,
    pub summary: Option<&'static str>,
    pub keywords: Option<&'static str>,
    pub errors: Option<&'static str>,
    pub related: Option<&'static str>,
    pub introduced: Option<&'static str>,
    pub status: Option<&'static str>,
    pub examples: Option<&'static str>,
}

// Collection point for `BuiltinDoc` registrations on native targets.
#[cfg(not(target_arch = "wasm32"))]
inventory::collect!(BuiltinDoc);

/// Returns every documentation record registered through `inventory`.
#[cfg(not(target_arch = "wasm32"))]
pub fn builtin_docs() -> Vec<&'static BuiltinDoc> {
    inventory::iter::<BuiltinDoc>().collect()
}

/// Returns every documentation record submitted to the wasm registry.
#[cfg(target_arch = "wasm32")]
pub fn builtin_docs() -> Vec<&'static BuiltinDoc> {
    wasm_registry::builtin_docs()
}
1834
/// Numeric display modes (named after MATLAB's `format` styles), selected
/// with `set_display_format` and applied by `format_number`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum FormatMode {
    /// Default: integers verbatim, else 4 decimals or 4-decimal scientific.
    #[default]
    Short,
    /// Integers verbatim, else 15 decimals or 14-decimal scientific.
    Long,
    /// Scientific notation with 4 mantissa decimals.
    ShortE,
    /// Scientific notation with 14 mantissa decimals.
    LongE,
    /// Compact notation with 5 significant digits.
    ShortG,
    /// Compact notation with 15 significant digits.
    LongG,
    /// Rational approximation `p/q`.
    Rational,
    /// Raw IEEE-754 bit pattern as 16 hex digits.
    Hex,
}

// Current display mode, stored via `runmat_thread_local!` (presumably
// per-thread); starts as `FormatMode::Short`.
runmat_thread_local! {
    static DISPLAY_FORMAT: RefCell<FormatMode> = const { RefCell::new(FormatMode::Short) };
}
1864
1865pub fn set_display_format(mode: FormatMode) {
1866 DISPLAY_FORMAT.with(|c| *c.borrow_mut() = mode);
1867}
1868
1869pub fn get_display_format() -> FormatMode {
1870 DISPLAY_FORMAT.with(|c| *c.borrow())
1871}
1872
1873pub fn format_number(value: f64) -> String {
1875 if value.is_nan() {
1876 return "NaN".to_string();
1877 }
1878 if value.is_infinite() {
1879 return if value.is_sign_negative() {
1880 "-Inf"
1881 } else {
1882 "Inf"
1883 }
1884 .to_string();
1885 }
1886 let mode = get_display_format();
1887 if mode == FormatMode::Hex {
1888 return fmt_hex(value);
1889 }
1890 let v = if value == 0.0 { 0.0 } else { value };
1891 match mode {
1892 FormatMode::Short => fmt_short(v),
1893 FormatMode::Long => fmt_long(v),
1894 FormatMode::ShortE => fmt_sci(v, 4),
1895 FormatMode::LongE => fmt_sci(v, 14),
1896 FormatMode::ShortG => fmt_compact(v, 5),
1897 FormatMode::LongG => fmt_compact(v, 15),
1898 FormatMode::Rational => fmt_rational(v),
1899 FormatMode::Hex => unreachable!("hex mode handled before zero normalization"),
1900 }
1901}
1902
/// Rewrites a Rust `{:e}` exponent (`1.5e3`, `2e-7`) into MATLAB style: an
/// explicit sign and at least two exponent digits (`1.5e+03`, `2e-07`).
///
/// Strings without an `e` are returned unchanged. An unparsable exponent is
/// treated as zero.
fn matlab_exp(s: &str) -> String {
    match s.split_once('e') {
        None => s.to_owned(),
        Some((mantissa, exponent)) => {
            let exp = exponent.parse::<i32>().unwrap_or(0);
            let sign = if exp < 0 { '-' } else { '+' };
            format!("{mantissa}e{sign}{:02}", exp.unsigned_abs())
        }
    }
}
1914
1915fn fmt_sci(v: f64, dec: usize) -> String {
1916 if v == 0.0 {
1917 return format!("0.{:0>dec$}e+00", 0, dec = dec);
1918 }
1919 let s = format!("{v:.dec$e}");
1920 matlab_exp(&s)
1921}
1922
1923fn fmt_short(v: f64) -> String {
1924 let abs = v.abs();
1925 if abs == 0.0 {
1926 return "0".to_string();
1927 }
1928 if v.fract() == 0.0 && abs < 1e15 {
1929 return format!("{}", v as i64);
1930 }
1931 if (0.001..10000.0).contains(&abs) {
1932 format!("{:.4}", v)
1933 } else {
1934 fmt_sci(v, 4)
1935 }
1936}
1937
1938fn fmt_long(v: f64) -> String {
1939 let abs = v.abs();
1940 if abs == 0.0 {
1941 return "0".to_string();
1942 }
1943 if v.fract() == 0.0 && abs < 1e15 {
1944 return format!("{}", v as i64);
1945 }
1946 if (0.001..10000.0).contains(&abs) {
1947 format!("{:.15}", v)
1948 } else {
1949 fmt_sci(v, 14)
1950 }
1951}
1952
1953fn fmt_compact(v: f64, sig_digits: usize) -> String {
1954 let abs = v.abs();
1955 if abs == 0.0 {
1956 return "0".to_string();
1957 }
1958 let use_scientific = !(1e-4..1e6).contains(&abs);
1959 if use_scientific {
1960 let dec = sig_digits - 1;
1961 let s = format!("{v:.dec$e}");
1962 if let Some(e_pos) = s.find('e') {
1964 let exp_part = &s[e_pos..];
1965 let mut mantissa = s[..e_pos].to_string();
1966 if let Some(dot) = mantissa.find('.') {
1967 let mut end = mantissa.len();
1968 while end > dot + 1 && mantissa.as_bytes()[end - 1] == b'0' {
1969 end -= 1;
1970 }
1971 if mantissa.as_bytes()[end - 1] == b'.' {
1972 end -= 1;
1973 }
1974 mantissa.truncate(end);
1975 }
1976 return matlab_exp(&format!("{mantissa}{exp_part}"));
1977 }
1978 return matlab_exp(&s);
1979 }
1980 let exp10 = abs.log10().floor() as i32;
1981 let decimals = ((sig_digits as i32 - 1 - exp10).max(0)) as usize;
1982 let pow = 10f64.powi(decimals as i32);
1983 let rounded = (v * pow).round() / pow;
1984 let mut s = format!("{rounded:.decimals$}");
1985 if let Some(dot) = s.find('.') {
1986 let mut end = s.len();
1987 while end > dot + 1 && s.as_bytes()[end - 1] == b'0' {
1988 end -= 1;
1989 }
1990 if s.as_bytes()[end - 1] == b'.' {
1991 end -= 1;
1992 }
1993 s.truncate(end);
1994 }
1995 if s.is_empty() || s == "-0" {
1996 s = "0".to_string();
1997 }
1998 s
1999}
2000
/// `FormatMode::Rational` rendering: `p/q`, or a bare integer.
///
/// * `NaN`, `Inf` and `-Inf` print as those words.
/// * Integer-valued inputs print their exact decimal value.
/// * Other values use continued-fraction convergents of `|v|`. The search
///   stops once the approximation is within `5e-7 * |v|`, when the next
///   denominator would exceed 1,000,000, or after 50 terms.
fn fmt_rational(v: f64) -> String {
    if v.is_nan() {
        return "NaN".to_string();
    }
    if v.is_infinite() {
        return if v < 0.0 { "-Inf" } else { "Inf" }.to_string();
    }
    if v == 0.0 {
        return "0".to_string();
    }
    // Every finite f64 with magnitude >= 2^52 is integral, so this branch also
    // covers all values too large for the i64 arithmetic below. Previously an
    // `as i64` cast saturated there and printed i64::MAX. `{:.0}` prints the
    // exact decimal value of an integral f64.
    if v.fract() == 0.0 {
        return format!("{v:.0}");
    }
    let negative = v < 0.0;
    let abs = v.abs();
    let tol = 5e-7 * abs;
    let max_d = 1_000_000i64;
    // Convergent recurrences: n_k = q*n_{k-1} + n_{k-2}, d_k = q*d_{k-1} + d_{k-2}.
    let mut n0: i64 = 1;
    // `abs < 2^52` here, so the integer part converts to i64 exactly.
    let mut n1: i64 = abs.floor() as i64;
    let mut d0: i64 = 0;
    let mut d1: i64 = 1;
    let mut a = abs;
    let mut best_n = n1;
    let mut best_d = d1;
    for _ in 0..50 {
        if (abs - best_n as f64 / best_d as f64).abs() <= tol {
            break;
        }
        let f = a.fract();
        if f < 1e-10 {
            break;
        }
        a = 1.0 / f;
        let q = a.floor() as i64;
        let Some(n2) = q.checked_mul(n1).and_then(|v| v.checked_add(n0)) else {
            break;
        };
        let Some(d2) = q.checked_mul(d1).and_then(|v| v.checked_add(d0)) else {
            break;
        };
        if d2 > max_d {
            break;
        }
        best_n = n2;
        best_d = d2;
        n0 = n1;
        n1 = n2;
        d0 = d1;
        d1 = d2;
    }
    let sign = if negative { "-" } else { "" };
    if best_d == 1 {
        format!("{sign}{best_n}")
    } else {
        format!("{sign}{best_n}/{best_d}")
    }
}
2054
/// `FormatMode::Hex` rendering: the IEEE-754 bit pattern as 16 lowercase,
/// zero-padded hex digits.
fn fmt_hex(v: f64) -> String {
    let bits: u64 = v.to_bits();
    format!("{bits:016x}")
}
2058
/// Runtime exception value (`identifier`, `message` and a stack of frames).
#[derive(Debug, Clone, PartialEq)]
pub struct MException {
    pub identifier: String,
    pub message: String,
    /// Stack entries as text; empty on construction.
    pub stack: Vec<String>,
}

impl MException {
    /// Creates an exception with the given identifier and message and an
    /// empty stack.
    pub fn new(identifier: String, message: String) -> Self {
        let stack = Vec::new();
        MException {
            stack,
            message,
            identifier,
        }
    }
}
2076
/// Reference to a handle-class object stored behind a GC pointer. Equality
/// is by target address (see the `PartialEq` impl).
#[derive(Debug, Clone)]
pub struct HandleRef {
    pub class_name: String,
    pub target: GcPtr<Value>,
    /// Reference-local validity flag; `is_handle_valid` additionally checks a
    /// hidden flag stored on the target object.
    pub valid: bool,
}

// Hidden property written onto a handle's target object by
// `set_handle_valid` and read back by `is_handle_valid`.
const HANDLE_VALID_FLAG_PROPERTY: &str = "__runmat_handle_valid__";
2086
/// Returns whether `handle` still refers to a valid object.
///
/// The handle is invalid when its own `valid` flag is cleared, when its
/// target pointer is null, or when the target is an object whose hidden
/// validity property is `Bool(false)`. Non-object targets count as valid.
pub fn is_handle_valid(handle: &HandleRef) -> bool {
    if !handle.valid {
        return false;
    }
    // SAFETY(review): assumes `as_raw` only requires a well-formed GcPtr —
    // confirm its contract in runmat_gc_api.
    let raw = unsafe { handle.target.as_raw() };
    if raw.is_null() {
        return false;
    }
    // SAFETY(review): non-null checked above; assumes the GC keeps the
    // pointee alive while the handle exists — TODO confirm.
    match unsafe { &*raw } {
        Value::Object(obj) => !matches!(
            obj.properties.get(HANDLE_VALID_FLAG_PROPERTY),
            Some(Value::Bool(false))
        ),
        _ => true,
    }
}
2103
/// Stores `valid` in the hidden validity property of the handle's target
/// object.
///
/// Returns `true` if the flag was written, or `false` when the target
/// pointer is null or the target is not a `Value::Object`. The handle's own
/// `valid` field is not modified.
pub fn set_handle_valid(handle: &HandleRef, valid: bool) -> bool {
    // SAFETY(review): mutates the target through a shared `&HandleRef`;
    // assumes the GC permits this aliasing and that no other borrow of the
    // target is live — confirm against runmat_gc_api.
    let raw = unsafe { handle.target.as_raw_mut() };
    if raw.is_null() {
        return false;
    }
    match unsafe { &mut *raw } {
        Value::Object(obj) => {
            obj.properties
                .insert(HANDLE_VALID_FLAG_PROPERTY.to_string(), Value::Bool(valid));
            true
        }
        _ => false,
    }
}
2118
2119impl PartialEq for HandleRef {
2120 fn eq(&self, other: &Self) -> bool {
2121 let a = unsafe { self.target.as_raw() } as usize;
2122 let b = unsafe { other.target.as_raw() } as usize;
2123 a == b
2124 }
2125}
2126
/// An event listener registered on a target object.
#[derive(Debug, Clone, PartialEq)]
pub struct Listener {
    pub id: u64,
    /// Object the listener is attached to.
    pub target: GcPtr<Value>,
    pub event_name: String,
    /// Callback value to invoke; its expected variant is not visible here —
    /// presumably a function handle or closure.
    pub callback: GcPtr<Value>,
    pub enabled: bool,
    pub valid: bool,
}
2137
2138impl Listener {
2139 pub fn class_name(&self) -> String {
2140 match unsafe { &*self.target.as_raw() } {
2141 Value::Object(o) => o.class_name.clone(),
2142 Value::HandleObject(h) => h.class_name.clone(),
2143 _ => String::new(),
2144 }
2145 }
2146}
2147
// User-facing text rendering of every runtime value variant.
impl fmt::Display for Value {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Value::Int(i) => write!(f, "{}", i.to_i64()),
            // Real numbers honor the current display format.
            Value::Num(n) => write!(f, "{}", format_number(*n)),
            // Zero parts are omitted; a negative imaginary part prints as `a-bi`.
            Value::Complex(re, im) => {
                if *im == 0.0 {
                    write!(f, "{}", format_number(*re))
                } else if *re == 0.0 {
                    write!(f, "{}i", format_number(*im))
                } else if *im < 0.0 {
                    write!(f, "{}-{}i", format_number(*re), format_number(im.abs()))
                } else {
                    write!(f, "{}+{}i", format_number(*re), format_number(*im))
                }
            }
            Value::Bool(b) => write!(f, "{}", if *b { 1 } else { 0 }),
            Value::LogicalArray(la) => write!(f, "{la}"),
            Value::String(s) => write!(f, "'{s}'"),
            Value::StringArray(sa) => write!(f, "{sa}"),
            Value::CharArray(ca) => write!(f, "{ca}"),
            Value::Tensor(m) => write!(f, "{m}"),
            Value::SparseTensor(m) => write!(f, "{m}"),
            Value::ComplexTensor(m) => write!(f, "{m}"),
            Value::Cell(ca) => ca.fmt(f),

            // Device tensors print metadata only; no data is read back.
            Value::GpuTensor(h) => write!(
                f,
                "GpuTensor(shape={:?}, device={}, buffer={})",
                h.shape, h.device_id, h.buffer_id
            ),
            Value::Object(obj) => write!(f, "{}(props={})", obj.class_name, obj.properties.len()),
            // Handles show their target address to make identity visible.
            Value::HandleObject(h) => {
                let ptr = unsafe { h.target.as_raw() } as usize;
                write!(
                    f,
                    "<handle {} @0x{:x} valid={}>",
                    h.class_name, ptr, h.valid
                )
            }
            Value::Listener(l) => {
                let ptr = unsafe { l.target.as_raw() } as usize;
                write!(
                    f,
                    "<listener id={} {}@0x{:x} '{}' enabled={} valid={}>",
                    l.id,
                    l.class_name(),
                    ptr,
                    l.event_name,
                    l.enabled,
                    l.valid
                )
            }
            Value::Struct(st) => {
                write!(f, "struct {{")?;
                for (i, (key, val)) in st.fields.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{}: {}", key, val)?;
                }
                write!(f, "}}")
            }
            Value::OutputList(values) => {
                write!(f, "[")?;
                for (i, value) in values.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{}", value)?;
                }
                write!(f, "]")
            }
            Value::FunctionHandle(name)
            | Value::ExternalFunctionHandle(name)
            | Value::MethodFunctionHandle(name) => {
                write!(f, "@{name}")
            }
            Value::BoundFunctionHandle { name, .. } => write!(f, "@{name}"),
            Value::Closure(c) => write!(
                f,
                "<closure {} captures={}>",
                c.function_name,
                c.captures.len()
            ),
            Value::ClassRef(name) => write!(f, "<class {name}>"),
            Value::MException(e) => write!(
                f,
                "MException(identifier='{}', message='{}')",
                e.identifier, e.message
            ),
        }
    }
}
2242
2243impl fmt::Display for ComplexTensor {
2244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2245 match self.shape.len() {
2246 0 | 1 => {
2247 write!(f, "[")?;
2248 for (i, (re, im)) in self.data.iter().enumerate() {
2249 if i > 0 {
2250 write!(f, " ")?;
2251 }
2252 let s = Value::Complex(*re, *im).to_string();
2253 write!(f, "{s}")?;
2254 }
2255 write!(f, "]")
2256 }
2257 2 => {
2258 let rows = self.rows;
2259 let cols = self.cols;
2260 write!(f, "[")?;
2261 for r in 0..rows {
2262 for c in 0..cols {
2263 if c > 0 {
2264 write!(f, " ")?;
2265 }
2266 let (re, im) = self.data[r + c * rows];
2267 let s = Value::Complex(re, im).to_string();
2268 write!(f, "{s}")?;
2269 }
2270 if r + 1 < rows {
2271 write!(f, "; ")?;
2272 }
2273 }
2274 write!(f, "]")
2275 }
2276 _ => {
2277 if should_expand_nd_display(&self.shape) {
2278 write_nd_pages(f, &self.shape, |f, idx| {
2279 let (re, im) = self.data[idx];
2280 write!(f, "{}", Value::Complex(re, im))
2281 })
2282 } else {
2283 write!(f, "ComplexTensor(shape={:?})", self.shape)
2284 }
2285 }
2286 }
2287 }
2288}
2289
// Unit tests for number formatting and the N-D `Display` implementations.
#[cfg(test)]
mod display_tests {
    use super::{
        fmt_rational, format_number, set_display_format, ComplexTensor, FormatMode, LogicalArray,
        Tensor,
    };

    // Large magnitudes must not panic inside the continued-fraction loop.
    // NOTE(review): this literal rounds to exactly 1e15 in f64 (its ulp is
    // 0.125), so the "tiny fract" is actually zero here.
    #[test]
    fn fmt_rational_large_value_with_tiny_fract_does_not_overflow() {
        let result = std::panic::catch_unwind(|| fmt_rational(1_000_000_000_000_000.000_1));
        assert!(
            result.is_ok(),
            "fmt_rational panicked on large value with tiny fract"
        );

        let result = std::panic::catch_unwind(|| fmt_rational(-1_000_000_000_000_000.000_1));
        assert!(
            result.is_ok(),
            "fmt_rational panicked on negative large value with tiny fract"
        );
    }

    #[test]
    fn tensor_nd_display_uses_page_headers() {
        let tensor = Tensor::new(
            vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            vec![2, 3, 2],
        )
        .expect("tensor");
        let rendered = tensor.to_string();
        assert!(rendered.contains("(:, :, 1) ="));
        assert!(rendered.contains("(:, :, 2) ="));
        assert!(rendered.contains(" 1 0 0"));
    }

    // Above the expansion limit only a shape summary is printed.
    #[test]
    fn tensor_nd_display_falls_back_for_large_arrays() {
        let tensor = Tensor::new(vec![0.0; 4097], vec![1, 1, 4097]).expect("tensor");
        assert_eq!(tensor.to_string(), "Tensor(shape=[1, 1, 4097])");
    }

    #[test]
    fn logical_nd_display_uses_headers_and_fallback_summary() {
        let logical =
            LogicalArray::new(vec![1, 0, 0, 1, 1, 0, 0, 1], vec![2, 2, 2]).expect("logical");
        let rendered = logical.to_string();
        assert!(rendered.contains("(:, :, 1) ="));
        assert!(rendered.contains("(:, :, 2) ="));

        let large = LogicalArray::new(vec![1; 4097], vec![1, 1, 4097]).expect("large logical");
        assert_eq!(large.to_string(), "1x1x4097 logical array");
    }

    #[test]
    fn complex_nd_display_uses_page_headers() {
        let complex = ComplexTensor::new(
            vec![(1.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0)],
            vec![2, 1, 2],
        )
        .expect("complex");
        let rendered = complex.to_string();
        assert!(rendered.contains("(:, :, 1) ="));
        assert!(rendered.contains("(:, :, 2) ="));
    }

    // Hex mode must bypass -0.0 normalization. The display mode is shared
    // state, so it is restored to `Short` afterwards.
    #[test]
    fn format_hex_preserves_negative_zero_sign_bit() {
        set_display_format(FormatMode::Hex);
        assert_eq!(format_number(-0.0), "8000000000000000");
        assert_eq!(format_number(0.0), "0000000000000000");
        set_display_format(FormatMode::Short);
    }
}
2366
/// Cell array of GC-managed values.
#[derive(Debug, Clone, PartialEq)]
pub struct CellArray {
    /// One handle per element; `get` reads index `row * cols + col`.
    pub data: Vec<GcPtr<Value>>,
    /// Full dimension list; its product equals `data.len()` when built via
    /// the constructors.
    pub shape: Vec<usize>,
    /// First dimension (1 for a 1-D shape, 0 for an empty shape).
    pub rows: usize,
    /// Second dimension (the sole dimension of a 1-D shape, 0 for an empty shape).
    pub cols: usize,
}
2377
2378impl CellArray {
2379 pub fn new_handles(
2380 handles: Vec<GcPtr<Value>>,
2381 rows: usize,
2382 cols: usize,
2383 ) -> Result<Self, String> {
2384 Self::new_handles_with_shape(handles, vec![rows, cols])
2385 }
2386
2387 pub fn new_handles_with_shape(
2388 handles: Vec<GcPtr<Value>>,
2389 shape: Vec<usize>,
2390 ) -> Result<Self, String> {
2391 let expected = total_len(&shape)
2392 .ok_or_else(|| "Cell data shape exceeds platform limits".to_string())?;
2393 if expected != handles.len() {
2394 return Err(format!(
2395 "Cell data length {} doesn't match shape {:?} ({} elements)",
2396 handles.len(),
2397 shape,
2398 expected
2399 ));
2400 }
2401 let (rows, cols) = shape_rows_cols(&shape);
2402 Ok(CellArray {
2403 data: handles,
2404 shape,
2405 rows,
2406 cols,
2407 })
2408 }
2409
2410 pub fn new(data: Vec<Value>, rows: usize, cols: usize) -> Result<Self, String> {
2411 Self::new_with_shape(data, vec![rows, cols])
2412 }
2413
2414 pub fn new_with_shape(data: Vec<Value>, shape: Vec<usize>) -> Result<Self, String> {
2415 let expected = total_len(&shape)
2416 .ok_or_else(|| "Cell data shape exceeds platform limits".to_string())?;
2417 if expected != data.len() {
2418 return Err(format!(
2419 "Cell data length {} doesn't match shape {:?} ({} elements)",
2420 data.len(),
2421 shape,
2422 expected
2423 ));
2424 }
2425 let handles: Vec<GcPtr<Value>> = data
2427 .into_iter()
2428 .map(|v| unsafe { GcPtr::from_raw(Box::into_raw(Box::new(v))) })
2429 .collect();
2430 Self::new_handles_with_shape(handles, shape)
2431 }
2432
2433 pub fn get(&self, row: usize, col: usize) -> Result<Value, String> {
2434 if row >= self.rows || col >= self.cols {
2435 return Err(format!(
2436 "Cell index ({row}, {col}) out of bounds for {}x{} cell array",
2437 self.rows, self.cols
2438 ));
2439 }
2440 Ok((*self.data[row * self.cols + col]).clone())
2441 }
2442}
2443
/// Number of elements described by `shape`, or `None` on `usize` overflow.
/// An empty shape describes zero elements.
fn total_len(shape: &[usize]) -> Option<usize> {
    match shape {
        [] => Some(0),
        dims => dims.iter().copied().try_fold(1usize, usize::checked_mul),
    }
}
2452
/// Leading `(rows, cols)` of a shape.
///
/// An empty shape is `(0, 0)`, a 1-D shape `[n]` is a row `(1, n)`, and any
/// higher rank uses its first two dimensions.
fn shape_rows_cols(shape: &[usize]) -> (usize, usize) {
    match shape {
        [] => (0, 0),
        [len] => (1, *len),
        [rows, cols, ..] => (*rows, *cols),
    }
}
2462
2463impl fmt::Display for CellArray {
2464 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2465 let dims: Vec<String> = self.shape.iter().map(|d| d.to_string()).collect();
2466 if self.shape.len() > 2 {
2467 return write!(f, "{} cell array", dims.join("x"));
2468 }
2469 write!(f, "{}x{} cell array", self.rows, self.cols)?;
2470 if self.rows == 0 || self.cols == 0 {
2471 return Ok(());
2472 }
2473 for r in 0..self.rows {
2474 writeln!(f)?;
2475 write!(f, " ")?;
2476 for c in 0..self.cols {
2477 if c > 0 {
2478 write!(f, " ")?;
2479 }
2480 let value = self.get(r, c).unwrap_or_else(|_| Value::Num(f64::NAN));
2481 write!(f, "{{{value}}}")?;
2482 }
2483 }
2484 Ok(())
2485 }
2486}
2487
2488#[derive(Debug, Clone, PartialEq)]
2489pub struct ObjectInstance {
2490 pub class_name: String,
2491 pub properties: HashMap<String, Value>,
2492}
2493
2494impl ObjectInstance {
2495 pub fn new(class_name: String) -> Self {
2496 Self {
2497 class_name,
2498 properties: HashMap::new(),
2499 }
2500 }
2501
2502 pub fn is_class(&self, name: &str) -> bool {
2503 self.class_name == name
2504 }
2505}
2506
/// Access level of a class member.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Access {
    Public,
    Private,
    Protected,
}

/// Declaration of a class property and its attributes.
#[derive(Debug, Clone)]
pub struct PropertyDef {
    pub name: String,
    pub is_static: bool,
    pub is_constant: bool,
    pub is_dependent: bool,
    pub get_access: Access,
    pub set_access: Access,
    /// Initial value, if the declaration provides one.
    pub default_value: Option<Value>,
}

/// Declaration of a class method and its attributes.
#[derive(Debug, Clone)]
pub struct MethodDef {
    pub name: String,
    pub is_static: bool,
    pub is_abstract: bool,
    pub is_sealed: bool,
    pub access: Access,
    /// Function implementing the method.
    pub function_name: String,
    /// Class name supplied implicitly to `function_name`, e.g.
    /// `primitive_class_registry` maps `double.zeros` to `zeros` with
    /// `Some("double")`; how it is passed is decided by the caller — confirm.
    pub implicit_class_argument: Option<String>,
}

/// A class definition as stored in the class registry.
#[derive(Debug, Clone)]
pub struct ClassDef {
    pub name: String,
    /// Parent class name, if any.
    pub parent: Option<String>,
    pub properties: HashMap<String, PropertyDef>,
    pub methods: HashMap<String, MethodDef>,
}
2544
2545use std::sync::Mutex;
2546
// Process-wide class metadata, each initialized lazily by its accessor below.
// Class definitions by name (seeded with the primitive classes).
static CLASS_REGISTRY: OnceLock<Mutex<HashMap<String, ClassDef>>> = OnceLock::new();
// Names of classes registered as sealed.
static SEALED_CLASS_REGISTRY: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
// Names of classes registered as abstract.
static ABSTRACT_CLASS_REGISTRY: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
// Values keyed by (class, member) — presumably static property storage; no
// accessor is visible in this section.
static STATIC_VALUES: OnceLock<Mutex<HashMap<(String, String), Value>>> = OnceLock::new();
// Enumeration member names per class.
static ENUMERATION_REGISTRY: OnceLock<Mutex<HashMap<String, HashSet<String>>>> = OnceLock::new();
2552
2553fn registry() -> &'static Mutex<HashMap<String, ClassDef>> {
2554 CLASS_REGISTRY.get_or_init(|| Mutex::new(primitive_class_registry()))
2555}
2556
2557fn sealed_registry() -> &'static Mutex<HashSet<String>> {
2558 SEALED_CLASS_REGISTRY.get_or_init(|| Mutex::new(HashSet::new()))
2559}
2560
2561fn abstract_registry() -> &'static Mutex<HashSet<String>> {
2562 ABSTRACT_CLASS_REGISTRY.get_or_init(|| Mutex::new(HashSet::new()))
2563}
2564
2565fn enumeration_registry() -> &'static Mutex<HashMap<String, HashSet<String>>> {
2566 ENUMERATION_REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
2567}
2568
2569fn primitive_class_registry() -> HashMap<String, ClassDef> {
2570 ["double", "single", "logical"]
2571 .into_iter()
2572 .map(|class_name| {
2573 let mut methods = HashMap::new();
2574 methods.insert(
2575 "zeros".to_string(),
2576 MethodDef {
2577 name: "zeros".to_string(),
2578 is_static: true,
2579 is_abstract: false,
2580 is_sealed: false,
2581 access: Access::Public,
2582 function_name: "zeros".to_string(),
2583 implicit_class_argument: Some(class_name.to_string()),
2584 },
2585 );
2586 (
2587 class_name.to_string(),
2588 ClassDef {
2589 name: class_name.to_string(),
2590 parent: None,
2591 properties: HashMap::new(),
2592 methods,
2593 },
2594 )
2595 })
2596 .collect()
2597}
2598
2599pub fn register_class(def: ClassDef) {
2600 register_class_with_modifiers(def, false, false);
2601}
2602
2603pub fn register_class_with_sealed(def: ClassDef, is_sealed: bool) {
2604 register_class_with_modifiers(def, is_sealed, false);
2605}
2606
2607pub fn register_class_with_modifiers(def: ClassDef, is_sealed: bool, is_abstract: bool) {
2608 let mut m = registry().lock().unwrap();
2609 let class_name = def.name.clone();
2610 m.insert(class_name.clone(), def);
2611 let mut sealed = sealed_registry().lock().unwrap();
2612 if is_sealed {
2613 sealed.insert(class_name.clone());
2614 } else {
2615 sealed.remove(&class_name);
2616 }
2617 let mut abstract_classes = abstract_registry().lock().unwrap();
2618 if is_abstract {
2619 abstract_classes.insert(class_name.clone());
2620 } else {
2621 abstract_classes.remove(&class_name);
2622 }
2623 enumeration_registry()
2624 .lock()
2625 .unwrap()
2626 .entry(class_name)
2627 .or_default();
2628}
2629
2630pub fn register_class_enumerations(class_name: &str, members: impl IntoIterator<Item = String>) {
2631 let mut registry = enumeration_registry().lock().unwrap();
2632 let entry = registry.entry(class_name.to_string()).or_default();
2633 entry.clear();
2634 entry.extend(members);
2635}
2636
2637pub fn class_has_enumeration_member(class_name: &str, member: &str) -> bool {
2638 enumeration_registry()
2639 .lock()
2640 .unwrap()
2641 .get(class_name)
2642 .is_some_and(|members| members.contains(member))
2643}
2644
2645pub fn get_class(name: &str) -> Option<ClassDef> {
2646 registry().lock().unwrap().get(name).cloned()
2647}
2648
2649pub fn class_names() -> Vec<String> {
2650 registry().lock().unwrap().keys().cloned().collect()
2651}
2652
2653pub fn is_class_sealed(name: &str) -> bool {
2654 sealed_registry().lock().unwrap().contains(name)
2655}
2656
2657pub fn is_class_abstract(name: &str) -> bool {
2658 abstract_registry().lock().unwrap().contains(name)
2659}
2660
2661pub fn is_class_or_subclass(class_name: &str, ancestor_name: &str) -> bool {
2662 if class_name == ancestor_name {
2663 return true;
2664 }
2665 let reg = registry().lock().unwrap();
2666 let mut current = Some(class_name.to_string());
2667 let mut visited = std::collections::HashSet::new();
2668 while let Some(name) = current {
2669 if !visited.insert(name.clone()) {
2670 break;
2671 }
2672 if name == ancestor_name {
2673 return true;
2674 }
2675 current = reg
2676 .get(&name)
2677 .and_then(|class_def| class_def.parent.clone());
2678 }
2679 false
2680}
2681
2682pub fn lookup_property(class_name: &str, prop: &str) -> Option<(PropertyDef, String)> {
2685 let reg = registry().lock().unwrap();
2686 let mut current = Some(class_name.to_string());
2687 let mut visited = std::collections::HashSet::new();
2688 while let Some(name) = current {
2689 if !visited.insert(name.clone()) {
2690 break;
2691 }
2692 if let Some(cls) = reg.get(&name) {
2693 if let Some(p) = cls.properties.get(prop) {
2694 return Some((p.clone(), name));
2695 }
2696 current = cls.parent.clone();
2697 } else {
2698 break;
2699 }
2700 }
2701 None
2702}
2703
2704pub fn lookup_method(class_name: &str, method: &str) -> Option<(MethodDef, String)> {
2707 let reg = registry().lock().unwrap();
2708 let mut current = Some(class_name.to_string());
2709 let mut visited = std::collections::HashSet::new();
2710 while let Some(name) = current {
2711 if !visited.insert(name.clone()) {
2712 break;
2713 }
2714 if let Some(cls) = reg.get(&name) {
2715 if let Some(m) = cls.methods.get(method) {
2716 return Some((m.clone(), name));
2717 }
2718 current = cls.parent.clone();
2719 } else {
2720 break;
2721 }
2722 }
2723 None
2724}
2725
2726fn static_values() -> &'static Mutex<HashMap<(String, String), Value>> {
2727 STATIC_VALUES.get_or_init(|| Mutex::new(HashMap::new()))
2728}
2729
2730pub fn get_static_property_value(class_name: &str, prop: &str) -> Option<Value> {
2731 static_values()
2732 .lock()
2733 .unwrap()
2734 .get(&(class_name.to_string(), prop.to_string()))
2735 .cloned()
2736}
2737
2738pub fn set_static_property_value(class_name: &str, prop: &str, value: Value) {
2739 static_values()
2740 .lock()
2741 .unwrap()
2742 .insert((class_name.to_string(), prop.to_string()), value);
2743}
2744
2745pub fn set_static_property_value_in_owner(
2747 class_name: &str,
2748 prop: &str,
2749 value: Value,
2750) -> Result<(), String> {
2751 if let Some((_p, owner)) = lookup_property(class_name, prop) {
2752 set_static_property_value(&owner, prop, value);
2753 Ok(())
2754 } else {
2755 Err(format!("Unknown static property '{class_name}.{prop}'"))
2756 }
2757}
2758
#[cfg(test)]
mod class_registry_tests {
    //! Tests for the process-global class registry.
    //!
    //! The registries are shared by every test in the binary and tests may run
    //! in parallel, so each user-defined class gets a fresh name from
    //! `unique_class_name` to keep tests independent of one another.

    use super::{
        class_has_enumeration_member, get_class, get_static_property_value, is_class_abstract,
        is_class_or_subclass, is_class_sealed, lookup_method, lookup_property, register_class,
        register_class_enumerations, register_class_with_modifiers,
        set_static_property_value_in_owner, Access, ClassDef, MethodDef, PropertyDef, Value,
    };
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU64, Ordering};

    static TEST_CLASS_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Returns `"{prefix}_{n}"` with a process-unique `n`.
    fn unique_class_name(prefix: &str) -> String {
        let id = TEST_CLASS_COUNTER.fetch_add(1, Ordering::Relaxed);
        format!("{prefix}_{id}")
    }

    /// A class with the given parent and no properties or methods.
    fn bare_class(name: &str, parent: Option<&str>) -> ClassDef {
        ClassDef {
            name: name.to_string(),
            parent: parent.map(str::to_string),
            properties: HashMap::new(),
            methods: HashMap::new(),
        }
    }

    /// A public, non-constant, non-dependent property with no default value.
    fn public_property(name: &str, is_static: bool) -> PropertyDef {
        PropertyDef {
            name: name.to_string(),
            is_static,
            is_constant: false,
            is_dependent: false,
            get_access: Access::Public,
            set_access: Access::Public,
            default_value: None,
        }
    }

    #[test]
    fn primitive_classes_expose_static_zeros_method_metadata() {
        for class_name in ["double", "single", "logical"] {
            let class_def = get_class(class_name).expect("primitive class should be registered");
            let method = class_def
                .methods
                .get("zeros")
                .expect("primitive class should expose zeros static method");
            assert!(method.is_static, "zeros should be static on {class_name}");
            assert_eq!(method.function_name, "zeros");
            assert_eq!(method.implicit_class_argument.as_deref(), Some(class_name));

            let (resolved, owner) =
                lookup_method(class_name, "zeros").expect("lookup should find primitive zeros");
            assert_eq!(owner, class_name);
            assert_eq!(resolved.function_name, "zeros");
            assert_eq!(
                resolved.implicit_class_argument.as_deref(),
                Some(class_name)
            );
        }
    }

    #[test]
    fn method_lookup_uses_parent_class_metadata_chain() {
        let parent_name = unique_class_name("plan6_parent");
        let child_name = unique_class_name("plan6_child");

        let mut parent_methods = HashMap::new();
        parent_methods.insert(
            "parentOnly".to_string(),
            MethodDef {
                name: "parentOnly".to_string(),
                is_static: false,
                is_abstract: false,
                is_sealed: false,
                access: Access::Public,
                function_name: "parentOnly_impl".to_string(),
                implicit_class_argument: None,
            },
        );
        register_class(ClassDef {
            methods: parent_methods,
            ..bare_class(&parent_name, None)
        });
        register_class(bare_class(&child_name, Some(parent_name.as_str())));

        let (method, owner) = lookup_method(&child_name, "parentOnly")
            .expect("child lookup should resolve inherited method through parent metadata");
        assert_eq!(owner, parent_name);
        assert_eq!(method.function_name, "parentOnly_impl");
    }

    #[test]
    fn method_lookup_handles_parent_cycle() {
        let class_a = unique_class_name("plan6_cycle_method_a");
        let class_b = unique_class_name("plan6_cycle_method_b");

        register_class(bare_class(&class_a, Some(class_b.as_str())));
        register_class(bare_class(&class_b, Some(class_a.as_str())));

        assert!(
            lookup_method(&class_a, "missing").is_none(),
            "cyclic parent metadata should terminate missing method lookup"
        );
    }

    #[test]
    fn property_lookup_uses_parent_class_metadata_chain() {
        let parent_name = unique_class_name("plan6_property_parent");
        let child_name = unique_class_name("plan6_property_child");

        let mut parent_properties = HashMap::new();
        parent_properties.insert(
            "parentFlag".to_string(),
            public_property("parentFlag", false),
        );
        register_class(ClassDef {
            properties: parent_properties,
            ..bare_class(&parent_name, None)
        });
        register_class(bare_class(&child_name, Some(parent_name.as_str())));

        let (property, owner) = lookup_property(&child_name, "parentFlag")
            .expect("child property lookup should resolve inherited property through parent");
        assert_eq!(owner, parent_name);
        assert_eq!(property.name, "parentFlag");
        assert!(!property.is_static);
    }

    #[test]
    fn property_lookup_handles_parent_cycle() {
        let class_a = unique_class_name("plan6_cycle_property_a");
        let class_b = unique_class_name("plan6_cycle_property_b");

        register_class(bare_class(&class_a, Some(class_b.as_str())));
        register_class(bare_class(&class_b, Some(class_a.as_str())));

        assert!(
            lookup_property(&class_a, "missing").is_none(),
            "cyclic parent metadata should terminate missing property lookup"
        );
    }

    #[test]
    fn subclass_check_walks_parent_chain() {
        let base = unique_class_name("plan6_subclass_base");
        let middle = unique_class_name("plan6_subclass_middle");
        let leaf = unique_class_name("plan6_subclass_leaf");

        register_class(bare_class(&base, None));
        register_class(bare_class(&middle, Some(base.as_str())));
        register_class(bare_class(&leaf, Some(middle.as_str())));

        assert!(is_class_or_subclass(&leaf, &leaf));
        assert!(is_class_or_subclass(&leaf, &middle));
        assert!(is_class_or_subclass(&leaf, &base));
        assert!(!is_class_or_subclass(&base, &leaf));
    }

    #[test]
    fn subclass_check_handles_parent_cycle() {
        let class_a = unique_class_name("plan6_cycle_subclass_a");
        let class_b = unique_class_name("plan6_cycle_subclass_b");

        register_class(bare_class(&class_a, Some(class_b.as_str())));
        register_class(bare_class(&class_b, Some(class_a.as_str())));

        assert!(is_class_or_subclass(&class_a, &class_b));
        assert!(
            !is_class_or_subclass(&class_a, "plan6_unrelated_ancestor"),
            "cyclic parent metadata should terminate the subclass walk"
        );
    }

    #[test]
    fn modifiers_are_set_and_cleared_on_reregistration() {
        let name = unique_class_name("plan6_modifiers");

        register_class_with_modifiers(bare_class(&name, None), true, true);
        assert!(is_class_sealed(&name));
        assert!(is_class_abstract(&name));

        register_class(bare_class(&name, None));
        assert!(!is_class_sealed(&name));
        assert!(!is_class_abstract(&name));
    }

    #[test]
    fn enumeration_members_are_replaced_not_merged() {
        let name = unique_class_name("plan6_enum");
        register_class(bare_class(&name, None));
        assert!(!class_has_enumeration_member(&name, "Red"));

        register_class_enumerations(&name, vec!["Red".to_string(), "Blue".to_string()]);
        assert!(class_has_enumeration_member(&name, "Red"));
        assert!(class_has_enumeration_member(&name, "Blue"));

        register_class_enumerations(&name, vec!["Green".to_string()]);
        assert!(!class_has_enumeration_member(&name, "Red"));
        assert!(class_has_enumeration_member(&name, "Green"));

        let unknown = unique_class_name("plan6_enum_missing");
        assert!(!class_has_enumeration_member(&unknown, "Red"));
    }

    #[test]
    fn static_property_value_is_stored_on_declaring_class() {
        let parent_name = unique_class_name("plan6_static_parent");
        let child_name = unique_class_name("plan6_static_child");

        let mut parent_properties = HashMap::new();
        parent_properties.insert("counter".to_string(), public_property("counter", true));
        register_class(ClassDef {
            properties: parent_properties,
            ..bare_class(&parent_name, None)
        });
        register_class(bare_class(&child_name, Some(parent_name.as_str())));

        set_static_property_value_in_owner(&child_name, "counter", Value::Num(1.0))
            .expect("inherited static property should be writable through the child");
        assert_eq!(
            get_static_property_value(&parent_name, "counter"),
            Some(Value::Num(1.0))
        );
        assert_eq!(get_static_property_value(&child_name, "counter"), None);

        assert!(
            set_static_property_value_in_owner(&child_name, "missing", Value::Num(0.0)).is_err()
        );
    }
}