1pub use inventory;
2pub mod symbolic;
3use runmat_gc_api::GcPtr;
4use runmat_thread_local::runmat_thread_local;
5use std::cell::RefCell;
6use std::collections::HashMap;
7use std::collections::HashSet;
8use std::convert::TryFrom;
9use std::fmt;
10use std::future::Future;
11use std::pin::Pin;
12pub use symbolic::{SymbolicExpr, SymbolicFunction};
13
14use indexmap::IndexMap;
15use std::sync::OnceLock;
16
/// Process-wide registry of builtins, constants and builtin docs used on `wasm32` targets.
///
/// Every submitted item is moved to the heap and leaked so it can be handed out as a
/// `&'static` reference; the accessor functions return cloned snapshots of those references.
#[cfg(target_arch = "wasm32")]
pub mod wasm_registry {
    use super::{BuiltinDoc, BuiltinFunction, Constant};
    use once_cell::sync::Lazy;
    use std::sync::Mutex;

    /// A lazily created, mutex-guarded list of leaked `'static` items.
    type Registry<T> = Lazy<Mutex<Vec<&'static T>>>;

    static FUNCTIONS: Registry<BuiltinFunction> = Lazy::new(|| Mutex::new(Vec::new()));
    static CONSTANTS: Registry<Constant> = Lazy::new(|| Mutex::new(Vec::new()));
    static DOCS: Registry<BuiltinDoc> = Lazy::new(|| Mutex::new(Vec::new()));
    static REGISTERED: Lazy<Mutex<bool>> = Lazy::new(|| Mutex::new(false));

    /// Leaks `item` (it is never freed) and appends the resulting reference to `registry`.
    fn push_leaked<T>(registry: &Mutex<Vec<&'static T>>, item: T) {
        let leaked: &'static T = Box::leak(Box::new(item));
        registry.lock().unwrap().push(leaked);
    }

    /// Returns a copy of the references currently stored in `registry`.
    fn snapshot<T>(registry: &Mutex<Vec<&'static T>>) -> Vec<&'static T> {
        registry.lock().unwrap().clone()
    }

    /// Registers a builtin function.
    pub fn submit_builtin_function(func: BuiltinFunction) {
        push_leaked(&FUNCTIONS, func);
    }

    /// Registers a constant.
    pub fn submit_constant(constant: Constant) {
        push_leaked(&CONSTANTS, constant);
    }

    /// Registers a builtin documentation entry.
    pub fn submit_builtin_doc(doc: BuiltinDoc) {
        push_leaked(&DOCS, doc);
    }

    /// Snapshot of all registered builtin functions, in submission order.
    pub fn builtin_functions() -> Vec<&'static BuiltinFunction> {
        snapshot(&FUNCTIONS)
    }

    /// Snapshot of all registered constants, in submission order.
    pub fn constants() -> Vec<&'static Constant> {
        snapshot(&CONSTANTS)
    }

    /// Snapshot of all registered builtin docs, in submission order.
    pub fn builtin_docs() -> Vec<&'static BuiltinDoc> {
        snapshot(&DOCS)
    }

    /// Records that registration has been performed.
    pub fn mark_registered() {
        let mut flag = REGISTERED.lock().unwrap();
        *flag = true;
    }

    /// Whether [`mark_registered`] has been called.
    pub fn is_registered() -> bool {
        let flag = REGISTERED.lock().unwrap();
        *flag
    }
}
68
/// A runtime value handled by the interpreter: MATLAB-style scalars, arrays, containers,
/// objects and callables.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// Integer scalar tagged with its integer class (see [`IntValue`]).
    Int(IntValue),
    /// Double-precision scalar.
    Num(f64),
    /// Complex scalar; presumably `(real, imaginary)` — TODO confirm against constructors.
    Complex(f64, f64),
    /// Logical scalar.
    Bool(bool),
    /// Logical array with an explicit shape.
    LogicalArray(LogicalArray),
    /// Scalar string.
    String(String),
    /// Array of strings.
    StringArray(StringArray),
    /// Character matrix.
    CharArray(CharArray),
    /// Dense numeric array.
    Tensor(Tensor),
    /// Sparse real matrix in compressed-sparse-column form.
    SparseTensor(SparseTensor),
    /// Dense complex array.
    ComplexTensor(ComplexTensor),
    /// Symbolic expression.
    Symbolic(SymbolicExpr),
    /// Cell array; its elements are stored behind pointers that dereference to `Value`.
    Cell(CellArray),
    /// Struct with named fields kept in insertion order.
    Struct(StructValue),
    /// Tensor handle from `runmat_accelerate_api`; its shape is read from the handle.
    /// NOTE(review): data presumably lives on an accelerator device — confirm.
    GpuTensor(runmat_accelerate_api::GpuTensorHandle),
    /// Instance of a user-defined class.
    Object(ObjectInstance),
    /// Reference to a handle-class object.
    HandleObject(HandleRef),
    /// Event listener.
    Listener(Listener),
    /// Several values produced together; presumably multiple function outputs — confirm.
    OutputList(Vec<Value>),
    /// Handle to a function identified by name.
    FunctionHandle(String),
    /// Handle to an external function identified by name.
    ExternalFunctionHandle(String),
    /// Handle to a method identified by name.
    MethodFunctionHandle(String),
    /// Named handle bound to a specific function.
    BoundFunctionHandle {
        /// Name the handle refers to.
        name: String,
        /// Identifier of the bound function; presumably an index into a function table —
        /// TODO confirm.
        function: usize,
    },
    /// Anonymous function together with its captured values.
    Closure(Closure),
    /// Reference to a class by name.
    ClassRef(String),
    /// Exception object.
    MException(MException),
}
/// An integer scalar that remembers its MATLAB integer class (`int8` … `uint64`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IntValue {
    I8(i8),
    I16(i16),
    I32(i32),
    I64(i64),
    U8(u8),
    U16(u16),
    U32(u32),
    U64(u64),
}

impl IntValue {
    /// Widens the value to `i64`.
    ///
    /// `U64` values above `i64::MAX` cannot be represented and saturate to `i64::MAX`.
    pub fn to_i64(&self) -> i64 {
        match self {
            IntValue::I8(v) => i64::from(*v),
            IntValue::I16(v) => i64::from(*v),
            IntValue::I32(v) => i64::from(*v),
            IntValue::I64(v) => *v,
            IntValue::U8(v) => i64::from(*v),
            IntValue::U16(v) => i64::from(*v),
            IntValue::U32(v) => i64::from(*v),
            IntValue::U64(v) => i64::try_from(*v).unwrap_or(i64::MAX),
        }
    }

    /// Converts the value to `f64`.
    ///
    /// Converts from the stored integer directly rather than via [`IntValue::to_i64`], so large
    /// `uint64` values are not clamped to `i64::MAX` first. Magnitudes above 2^53 round to the
    /// nearest representable `f64`.
    pub fn to_f64(&self) -> f64 {
        match self {
            IntValue::I8(v) => f64::from(*v),
            IntValue::I16(v) => f64::from(*v),
            IntValue::I32(v) => f64::from(*v),
            IntValue::I64(v) => *v as f64,
            IntValue::U8(v) => f64::from(*v),
            IntValue::U16(v) => f64::from(*v),
            IntValue::U32(v) => f64::from(*v),
            IntValue::U64(v) => *v as f64,
        }
    }

    /// Whether the value equals zero.
    pub fn is_zero(&self) -> bool {
        self.to_i64() == 0
    }

    /// MATLAB class name of the integer type.
    pub fn class_name(&self) -> &'static str {
        match self {
            IntValue::I8(_) => "int8",
            IntValue::I16(_) => "int16",
            IntValue::I32(_) => "int32",
            IntValue::I64(_) => "int64",
            IntValue::U8(_) => "uint8",
            IntValue::U16(_) => "uint16",
            IntValue::U32(_) => "uint32",
            IntValue::U64(_) => "uint64",
        }
    }
}
169
170#[derive(Debug, Clone, PartialEq)]
171pub struct StructValue {
172 pub fields: IndexMap<String, Value>,
173}
174
175impl StructValue {
176 pub fn new() -> Self {
177 Self {
178 fields: IndexMap::new(),
179 }
180 }
181
182 pub fn insert(&mut self, name: impl Into<String>, value: Value) -> Option<Value> {
184 self.fields.insert(name.into(), value)
185 }
186
187 pub fn remove(&mut self, name: &str) -> Option<Value> {
189 self.fields.shift_remove(name)
190 }
191
192 pub fn field_names(&self) -> impl Iterator<Item = &String> {
194 self.fields.keys()
195 }
196}
197
198impl Default for StructValue {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
/// Element class of a [`Tensor`]. Tensor data is held as `f64` regardless of this tag.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NumericDType {
    /// MATLAB `double`.
    F64,
    /// MATLAB `single`.
    F32,
    /// MATLAB `uint8`.
    U8,
    /// MATLAB `uint16`.
    U16,
}
211
212impl NumericDType {
213 pub fn class_name(self) -> &'static str {
214 match self {
215 NumericDType::F64 => "double",
216 NumericDType::F32 => "single",
217 NumericDType::U8 => "uint8",
218 NumericDType::U16 => "uint16",
219 }
220 }
221
222 pub fn byte_size(self) -> usize {
223 match self {
224 NumericDType::F64 => 8,
225 NumericDType::F32 => 4,
226 NumericDType::U8 => 1,
227 NumericDType::U16 => 2,
228 }
229 }
230}
231
/// Dense numeric array, stored column-major as `f64` with a separate element-class tag.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Elements in column-major order (for 2-D, element `(r, c)` is `data[r + c * rows]`).
    pub data: Vec<f64>,
    /// Full dimension list; the constructors require `data.len()` to equal its product.
    pub shape: Vec<usize>,
    /// Cached row count set by the constructors: `shape[0]` for rank >= 2, `1` for rank 1,
    /// `0` for rank 0. NOTE: `Tensor::rows()` instead reports `shape[0]` (or `1` when empty).
    pub rows: usize,
    /// Cached column count: `shape[1]` for rank >= 2, `shape[0]` for rank 1, `0` for rank 0.
    /// NOTE: `Tensor::cols()` instead reports `shape[1]` (or `1` when absent).
    pub cols: usize,
    /// Logical element class (e.g. `single` for tensors built by `from_f32`).
    pub dtype: NumericDType,
}
241
/// Sparse real matrix in compressed-sparse-column (CSC) form.
///
/// `SparseTensor::new` enforces `col_ptrs.len() == cols + 1`, `col_ptrs[0] == 0`,
/// nondecreasing `col_ptrs` ending at `values.len()`, and per-column row indices that are
/// strictly increasing and `< rows`. The fields are public, so struct literals can bypass
/// these checks.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseTensor {
    pub rows: usize,
    pub cols: usize,
    /// Column `c`'s entries occupy positions `col_ptrs[c]..col_ptrs[c + 1]` of
    /// `row_indices`/`values`.
    pub col_ptrs: Vec<usize>,
    /// Zero-based row of each stored entry.
    pub row_indices: Vec<usize>,
    /// Value of each stored entry, parallel to `row_indices`.
    pub values: Vec<f64>,
}
252
/// Dense complex array.
///
/// NOTE(review): element pairs are presumably `(real, imaginary)` stored column-major like
/// `Tensor` — confirm against the code that reads `data`.
#[derive(Debug, Clone, PartialEq)]
pub struct ComplexTensor {
    pub data: Vec<(f64, f64)>,
    /// Full dimension list; `ComplexTensor::new` requires `data.len()` to equal its product.
    pub shape: Vec<usize>,
    /// Cached row count: `shape[0]` for rank >= 2, `1` for rank 1, `0` for rank 0.
    pub rows: usize,
    /// Cached column count: `shape[1]` for rank >= 2, `shape[0]` for rank 1, `0` for rank 0.
    pub cols: usize,
}
260
/// Array of strings stored column-major (element `(r, c)` is `data[r + c * rows]`).
/// The literal text `"<missing>"` is displayed as a missing value.
#[derive(Debug, Clone, PartialEq)]
pub struct StringArray {
    pub data: Vec<String>,
    /// Full dimension list; `StringArray::new` requires `data.len()` to equal its product.
    pub shape: Vec<usize>,
    /// Cached row count: `shape[0]` for rank >= 2, `1` for rank 1, `0` for rank 0.
    pub rows: usize,
    /// Cached column count: `shape[1]` for rank >= 2, `shape[0]` for rank 1, `0` for rank 0.
    pub cols: usize,
}
268
/// Logical (boolean) array stored as one byte per element.
#[derive(Debug, Clone, PartialEq)]
pub struct LogicalArray {
    /// Element values; [`LogicalArray::new`] normalises each one to 0 or 1.
    pub data: Vec<u8>,
    /// Dimensions; `data.len()` equals their product.
    pub shape: Vec<usize>,
}

impl LogicalArray {
    /// Builds a logical array from raw bytes, mapping every nonzero byte to 1.
    ///
    /// # Errors
    /// Returns an error when `data.len()` differs from the product of `shape`.
    pub fn new(data: Vec<u8>, shape: Vec<usize>) -> Result<Self, String> {
        let expected: usize = shape.iter().product();
        if data.len() != expected {
            return Err(format!(
                "LogicalArray data length {} doesn't match shape {:?} ({} elements)",
                data.len(),
                shape,
                expected
            ));
        }
        let data: Vec<u8> = data.into_iter().map(|byte| u8::from(byte != 0)).collect();
        Ok(Self { data, shape })
    }

    /// An all-false array of the given shape.
    pub fn zeros(shape: Vec<usize>) -> Self {
        let count = shape.iter().product::<usize>();
        Self {
            data: vec![0; count],
            shape,
        }
    }

    /// Number of elements.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Whether the array has no elements.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
307
/// A character matrix (MATLAB `char` array).
///
/// NOTE(review): the `Display` impl for this type reads `data[r * cols + c]` (row-major),
/// unlike the column-major numeric arrays in this file — confirm before depending on layout.
#[derive(Debug, Clone, PartialEq)]
pub struct CharArray {
    pub data: Vec<char>,
    pub rows: usize,
    pub cols: usize,
}

impl CharArray {
    /// Builds a 1×N char array from the characters (Unicode scalar values) of `s`.
    pub fn new_row(s: &str) -> Self {
        let data: Vec<char> = s.chars().collect();
        let cols = data.len();
        CharArray { data, rows: 1, cols }
    }

    /// Builds a `rows`×`cols` char array from `data`.
    ///
    /// # Errors
    /// Returns an error when `rows * cols` differs from `data.len()`, including when the
    /// product overflows `usize` (an unchecked multiply could otherwise panic or wrap and
    /// accept dimensions that do not describe `data`).
    pub fn new(data: Vec<char>, rows: usize, cols: usize) -> Result<Self, String> {
        if rows.checked_mul(cols) != Some(data.len()) {
            return Err(format!(
                "Char data length {} doesn't match dimensions {}x{}",
                data.len(),
                rows,
                cols
            ));
        }
        Ok(CharArray { data, rows, cols })
    }
}
335
336impl StringArray {
337 pub fn new(data: Vec<String>, shape: Vec<usize>) -> Result<Self, String> {
338 let expected: usize = shape.iter().product();
339 if data.len() != expected {
340 return Err(format!(
341 "StringArray data length {} doesn't match shape {:?} ({} elements)",
342 data.len(),
343 shape,
344 expected
345 ));
346 }
347 let (rows, cols) = if shape.len() >= 2 {
348 (shape[0], shape[1])
349 } else if shape.len() == 1 {
350 (1, shape[0])
351 } else {
352 (0, 0)
353 };
354 Ok(StringArray {
355 data,
356 shape,
357 rows,
358 cols,
359 })
360 }
361 pub fn new_2d(data: Vec<String>, rows: usize, cols: usize) -> Result<Self, String> {
362 Self::new(data, vec![rows, cols])
363 }
364 pub fn rows(&self) -> usize {
365 self.shape.first().copied().unwrap_or(1)
366 }
367 pub fn cols(&self) -> usize {
368 self.shape.get(1).copied().unwrap_or(1)
369 }
370}
371
372impl Tensor {
375 pub fn new(data: Vec<f64>, shape: Vec<usize>) -> Result<Self, String> {
376 let expected: usize = shape.iter().product();
377 if data.len() != expected {
378 return Err(format!(
379 "Tensor data length {} doesn't match shape {:?} ({} elements)",
380 data.len(),
381 shape,
382 expected
383 ));
384 }
385 let (rows, cols) = if shape.len() >= 2 {
386 (shape[0], shape[1])
387 } else if shape.len() == 1 {
388 (1, shape[0])
389 } else {
390 (0, 0)
391 };
392 Ok(Tensor {
393 data,
394 shape,
395 rows,
396 cols,
397 dtype: NumericDType::F64,
398 })
399 }
400
401 pub fn new_2d(data: Vec<f64>, rows: usize, cols: usize) -> Result<Self, String> {
402 Self::new(data, vec![rows, cols])
403 }
404
405 pub fn from_f32(data: Vec<f32>, shape: Vec<usize>) -> Result<Self, String> {
406 let converted: Vec<f64> = data.into_iter().map(|v| v as f64).collect();
407 Self::new_with_dtype(converted, shape, NumericDType::F32)
408 }
409
410 pub fn from_f32_slice(data: &[f32], shape: &[usize]) -> Result<Self, String> {
411 let converted: Vec<f64> = data.iter().map(|&v| v as f64).collect();
412 Self::new_with_dtype(converted, shape.to_vec(), NumericDType::F32)
413 }
414
415 pub fn new_with_dtype(
416 data: Vec<f64>,
417 shape: Vec<usize>,
418 dtype: NumericDType,
419 ) -> Result<Self, String> {
420 let mut t = Self::new(data, shape)?;
421 t.dtype = dtype;
422 Ok(t)
423 }
424
425 pub fn zeros(shape: Vec<usize>) -> Self {
426 let size: usize = shape.iter().product();
427 let (rows, cols) = if shape.len() >= 2 {
428 (shape[0], shape[1])
429 } else if shape.len() == 1 {
430 (1, shape[0])
431 } else {
432 (0, 0)
433 };
434 Tensor {
435 data: vec![0.0; size],
436 shape,
437 rows,
438 cols,
439 dtype: NumericDType::F64,
440 }
441 }
442
443 pub fn ones(shape: Vec<usize>) -> Self {
444 let size: usize = shape.iter().product();
445 let (rows, cols) = if shape.len() >= 2 {
446 (shape[0], shape[1])
447 } else if shape.len() == 1 {
448 (1, shape[0])
449 } else {
450 (0, 0)
451 };
452 Tensor {
453 data: vec![1.0; size],
454 shape,
455 rows,
456 cols,
457 dtype: NumericDType::F64,
458 }
459 }
460
461 pub fn zeros2(rows: usize, cols: usize) -> Self {
463 Self::zeros(vec![rows, cols])
464 }
465 pub fn ones2(rows: usize, cols: usize) -> Self {
466 Self::ones(vec![rows, cols])
467 }
468
469 pub fn rows(&self) -> usize {
470 self.shape.first().copied().unwrap_or(1)
471 }
472 pub fn cols(&self) -> usize {
473 self.shape.get(1).copied().unwrap_or(1)
474 }
475
476 pub fn get2(&self, row: usize, col: usize) -> Result<f64, String> {
477 let rows = self.rows();
478 let cols = self.cols();
479 if row >= rows || col >= cols {
480 return Err(format!(
481 "Index ({row}, {col}) out of bounds for {rows}x{cols} tensor"
482 ));
483 }
484 Ok(self.data[row + col * rows])
486 }
487
488 pub fn set2(&mut self, row: usize, col: usize, value: f64) -> Result<(), String> {
489 let rows = self.rows();
490 let cols = self.cols();
491 if row >= rows || col >= cols {
492 return Err(format!(
493 "Index ({row}, {col}) out of bounds for {rows}x{cols} tensor"
494 ));
495 }
496 self.data[row + col * rows] = value;
498 Ok(())
499 }
500
501 pub fn scalar_to_tensor2(scalar: f64, rows: usize, cols: usize) -> Tensor {
502 Tensor {
503 data: vec![scalar; rows * cols],
504 shape: vec![rows, cols],
505 rows,
506 cols,
507 dtype: NumericDType::F64,
508 }
509 }
510 }
512
513impl SparseTensor {
514 pub fn new(
515 rows: usize,
516 cols: usize,
517 col_ptrs: Vec<usize>,
518 row_indices: Vec<usize>,
519 values: Vec<f64>,
520 ) -> Result<Self, String> {
521 if col_ptrs.len() != cols.saturating_add(1) {
522 return Err(format!(
523 "SparseTensor col_ptrs length {} doesn't match cols {}",
524 col_ptrs.len(),
525 cols
526 ));
527 }
528 if row_indices.len() != values.len() {
529 return Err(format!(
530 "SparseTensor row index length {} doesn't match value length {}",
531 row_indices.len(),
532 values.len()
533 ));
534 }
535 if col_ptrs.first().copied().unwrap_or(usize::MAX) != 0 {
536 return Err("SparseTensor col_ptrs must start at 0".to_string());
537 }
538 if col_ptrs.last().copied().unwrap_or(usize::MAX) != values.len() {
539 return Err("SparseTensor final col_ptr must equal nnz".to_string());
540 }
541 for window in col_ptrs.windows(2) {
542 if window[0] > window[1] {
543 return Err("SparseTensor col_ptrs must be nondecreasing".to_string());
544 }
545 }
546 for col in 0..cols {
547 let start = col_ptrs[col];
548 let end = col_ptrs[col + 1];
549 let mut prev: Option<usize> = None;
550 for &row in &row_indices[start..end] {
551 if row >= rows {
552 return Err(format!("SparseTensor row index {row} exceeds rows {rows}"));
553 }
554 if prev.is_some_and(|p| p >= row) {
555 return Err("SparseTensor row indices must be sorted and unique".to_string());
556 }
557 prev = Some(row);
558 }
559 }
560 Ok(Self {
561 rows,
562 cols,
563 col_ptrs,
564 row_indices,
565 values,
566 })
567 }
568
569 pub fn zeros(rows: usize, cols: usize) -> Self {
570 Self {
571 rows,
572 cols,
573 col_ptrs: vec![0; cols.saturating_add(1)],
574 row_indices: Vec::new(),
575 values: Vec::new(),
576 }
577 }
578
579 pub fn nnz(&self) -> usize {
580 self.values.len()
581 }
582
583 pub fn shape(&self) -> Vec<usize> {
584 vec![self.rows, self.cols]
585 }
586
587 pub fn to_dense(&self) -> Result<Tensor, String> {
588 let len = self
589 .rows
590 .checked_mul(self.cols)
591 .ok_or_else(|| "SparseTensor dense dimensions overflow usize".to_string())?;
592 let mut data = Vec::new();
593 data.try_reserve_exact(len)
594 .map_err(|err| format!("SparseTensor dense allocation failed: {err}"))?;
595 data.resize(len, 0.0);
596 for col in 0..self.cols {
597 for idx in self.col_ptrs[col]..self.col_ptrs[col + 1] {
598 let row = self.row_indices[idx];
599 data[row + col * self.rows] = self.values[idx];
600 }
601 }
602 Tensor::new(data, self.shape())
603 }
604
605 pub fn get(&self, row: usize, col: usize) -> Option<f64> {
606 if row >= self.rows || col >= self.cols {
607 return None;
608 }
609 let start = self.col_ptrs[col];
610 let end = self.col_ptrs[col + 1];
611 self.row_indices[start..end]
612 .binary_search(&row)
613 .ok()
614 .map(|offset| self.values[start + offset])
615 }
616}
617
#[cfg(test)]
mod sparse_tensor_tests {
    use super::*;

    /// `to_dense` must report an error for shapes whose element count does not fit in
    /// `usize`, rather than panicking or wrapping around.
    #[test]
    fn to_dense_rejects_overflowing_dimensions() {
        let no_entries_in_two_columns = vec![0usize; 3];
        let sparse = SparseTensor {
            rows: usize::MAX,
            cols: 2,
            col_ptrs: no_entries_in_two_columns,
            row_indices: vec![],
            values: vec![],
        };

        match sparse.to_dense() {
            Ok(_) => panic!("expected to_dense to fail for overflowing dimensions"),
            Err(message) => assert!(message.contains("overflow")),
        }
    }
}
636
637impl ComplexTensor {
638 pub fn new(data: Vec<(f64, f64)>, shape: Vec<usize>) -> Result<Self, String> {
639 let expected: usize = shape.iter().product();
640 if data.len() != expected {
641 return Err(format!(
642 "ComplexTensor data length {} doesn't match shape {:?} ({} elements)",
643 data.len(),
644 shape,
645 expected
646 ));
647 }
648 let (rows, cols) = if shape.len() >= 2 {
649 (shape[0], shape[1])
650 } else if shape.len() == 1 {
651 (1, shape[0])
652 } else {
653 (0, 0)
654 };
655 Ok(ComplexTensor {
656 data,
657 shape,
658 rows,
659 cols,
660 })
661 }
662 pub fn new_2d(data: Vec<(f64, f64)>, rows: usize, cols: usize) -> Result<Self, String> {
663 Self::new(data, vec![rows, cols])
664 }
665 pub fn zeros(shape: Vec<usize>) -> Self {
666 let size: usize = shape.iter().product();
667 let (rows, cols) = if shape.len() >= 2 {
668 (shape[0], shape[1])
669 } else if shape.len() == 1 {
670 (1, shape[0])
671 } else {
672 (0, 0)
673 };
674 ComplexTensor {
675 data: vec![(0.0, 0.0); size],
676 shape,
677 rows,
678 cols,
679 }
680 }
681}
682
683const MAX_ND_DISPLAY_ELEMENTS: usize = 4096;
684
685fn should_expand_nd_display(shape: &[usize]) -> bool {
686 shape.len() > 2
687 && matches!(
688 total_len(shape),
689 Some(total) if total > 0 && total <= MAX_ND_DISPLAY_ELEMENTS
690 )
691}
692
/// Column-major strides for `shape`: entry `k` is the product of all dimensions before `k`
/// (so the first stride is always 1). Products saturate at `usize::MAX` instead of overflowing.
fn column_major_strides(shape: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .scan(1usize, |running, &dim| {
            let current = *running;
            *running = running.saturating_mul(dim);
            Some(current)
        })
        .collect()
}
702
/// Decodes a linear page number into zero-based coordinates over `page_shape`, with the
/// first dimension varying fastest (column-major). Zero-length dimensions yield coordinate 0
/// and do not consume any of the index.
fn decode_page_coords(mut page_index: usize, page_shape: &[usize]) -> Vec<usize> {
    page_shape
        .iter()
        .map(|&extent| match extent {
            0 => 0,
            _ => {
                let coord = page_index % extent;
                page_index /= extent;
                coord
            }
        })
        .collect()
}
715
716fn write_nd_pages(
717 f: &mut fmt::Formatter<'_>,
718 shape: &[usize],
719 mut write_element: impl FnMut(&mut fmt::Formatter<'_>, usize) -> fmt::Result,
720) -> fmt::Result {
721 if shape.len() <= 2 {
722 return Ok(());
723 }
724 let rows = shape[0];
725 let cols = shape[1];
726 if rows == 0 || cols == 0 {
727 return write!(f, "[]");
728 }
729 let Some(page_count) = total_len(&shape[2..]) else {
730 return write!(f, "Tensor(shape={shape:?})");
731 };
732 if page_count == 0 {
733 return write!(f, "[]");
734 }
735 let strides = column_major_strides(shape);
736 for page_index in 0..page_count {
737 if page_index > 0 {
738 write!(f, "\n\n")?;
739 }
740 let coords = decode_page_coords(page_index, &shape[2..]);
741 write!(f, "(:, :")?;
742 for &coord in &coords {
743 write!(f, ", {}", coord + 1)?;
744 }
745 write!(f, ") =")?;
746
747 let mut page_base = 0usize;
748 for (offset, &coord) in coords.iter().enumerate() {
749 page_base += coord * strides[offset + 2];
750 }
751 for r in 0..rows {
752 writeln!(f)?;
753 write!(f, " ")?;
754 for c in 0..cols {
755 if c > 0 {
756 write!(f, " ")?;
757 }
758 let linear = page_base + r + c * rows;
759 write_element(f, linear)?;
760 }
761 }
762 }
763 Ok(())
764}
765
766impl fmt::Display for Tensor {
767 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
768 match self.shape.len() {
769 0 | 1 => {
770 write!(f, "[")?;
772 for (i, v) in self.data.iter().enumerate() {
773 if i > 0 {
774 write!(f, " ")?;
775 }
776 write!(f, "{}", format_number(*v))?;
777 }
778 write!(f, "]")
779 }
780 2 => {
781 let rows = self.rows();
782 let cols = self.cols();
783 for r in 0..rows {
785 writeln!(f)?;
786 write!(f, " ")?; for c in 0..cols {
788 if c > 0 {
789 write!(f, " ")?;
790 }
791 let v = self.data[r + c * rows];
792 write!(f, "{}", format_number(v))?;
793 }
794 }
795 Ok(())
796 }
797 _ => {
798 if should_expand_nd_display(&self.shape) {
799 write_nd_pages(f, &self.shape, |f, idx| {
800 write!(f, "{}", format_number(self.data[idx]))
801 })
802 } else {
803 write!(f, "Tensor(shape={:?})", self.shape)
804 }
805 }
806 }
807 }
808}
809
810impl fmt::Display for SparseTensor {
811 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
812 writeln!(
813 f,
814 "{}x{} sparse double matrix with {} nonzero entries",
815 self.rows,
816 self.cols,
817 self.nnz()
818 )?;
819 if self.nnz() == 0 {
820 return Ok(());
821 }
822 for col in 0..self.cols {
823 for idx in self.col_ptrs[col]..self.col_ptrs[col + 1] {
824 let row = self.row_indices[idx];
825 writeln!(
826 f,
827 " ({},{}) {}",
828 row + 1,
829 col + 1,
830 format_number(self.values[idx])
831 )?;
832 }
833 }
834 Ok(())
835 }
836}
837
838impl fmt::Display for StringArray {
839 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
840 let (rows, cols) = match self.shape.len() {
841 0 => (0, 0),
842 1 => (1, self.shape[0]),
843 _ => (self.shape[0], self.shape[1]),
844 };
845 let count = self.data.len();
846 if count == 1 && rows == 1 && cols == 1 {
847 let v = &self.data[0];
848 if v == "<missing>" {
849 return write!(f, "<missing>");
850 }
851 let escaped = v.replace('"', "\\\"");
852 return write!(f, "\"{escaped}\"");
853 }
854 if self.shape.len() > 2 {
855 let dims: Vec<String> = self.shape.iter().map(|d| d.to_string()).collect();
856 return write!(f, "{} string array", dims.join("x"));
857 }
858 write!(f, "{rows}x{cols} string array")?;
859 if rows == 0 || cols == 0 {
860 return Ok(());
861 }
862 for r in 0..rows {
863 writeln!(f)?;
864 write!(f, " ")?;
865 for c in 0..cols {
866 if c > 0 {
867 write!(f, " ")?;
868 }
869 let v = &self.data[r + c * rows];
870 if v == "<missing>" {
871 write!(f, "<missing>")?;
872 } else {
873 let escaped = v.replace('"', "\\\"");
874 write!(f, "\"{escaped}\"")?;
875 }
876 }
877 }
878 Ok(())
879 }
880}
881
882impl fmt::Display for LogicalArray {
883 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
884 if self.data.len() == 1 {
885 return write!(f, "{}", if self.data[0] != 0 { 1 } else { 0 });
886 }
887 match self.shape.len() {
888 0 => write!(f, "[]"),
889 1 => {
890 write!(f, "[")?;
891 for (i, v) in self.data.iter().enumerate() {
892 if i > 0 {
893 write!(f, " ")?;
894 }
895 write!(f, "{}", if *v != 0 { 1 } else { 0 })?;
896 }
897 write!(f, "]")
898 }
899 2 => {
900 let rows = self.shape[0];
901 let cols = self.shape[1];
902 for r in 0..rows {
904 writeln!(f)?;
905 write!(f, " ")?; for c in 0..cols {
907 if c > 0 {
908 write!(f, " ")?;
909 }
910 let idx = r + c * rows;
911 write!(f, "{}", if self.data[idx] != 0 { 1 } else { 0 })?;
912 }
913 }
914 Ok(())
915 }
916 _ => {
917 if should_expand_nd_display(&self.shape) {
918 write_nd_pages(f, &self.shape, |f, idx| {
919 write!(f, "{}", if self.data[idx] != 0 { 1 } else { 0 })
920 })
921 } else {
922 let dims: Vec<String> = self.shape.iter().map(|d| d.to_string()).collect();
923 write!(f, "{} logical array", dims.join("x"))
924 }
925 }
926 }
927 }
928}
929
930impl fmt::Display for CharArray {
931 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
932 for r in 0..self.rows {
933 writeln!(f)?;
934 write!(f, " ")?; for c in 0..self.cols {
936 let ch = self.data[r * self.cols + c];
937 write!(f, "{ch}")?;
938 }
939 }
940 Ok(())
941 }
942}
943
944impl From<i32> for Value {
946 fn from(i: i32) -> Self {
947 Value::Int(IntValue::I32(i))
948 }
949}
950impl From<i64> for Value {
951 fn from(i: i64) -> Self {
952 Value::Int(IntValue::I64(i))
953 }
954}
955impl From<u32> for Value {
956 fn from(i: u32) -> Self {
957 Value::Int(IntValue::U32(i))
958 }
959}
960impl From<u64> for Value {
961 fn from(i: u64) -> Self {
962 Value::Int(IntValue::U64(i))
963 }
964}
965impl From<i16> for Value {
966 fn from(i: i16) -> Self {
967 Value::Int(IntValue::I16(i))
968 }
969}
970impl From<i8> for Value {
971 fn from(i: i8) -> Self {
972 Value::Int(IntValue::I8(i))
973 }
974}
975impl From<u16> for Value {
976 fn from(i: u16) -> Self {
977 Value::Int(IntValue::U16(i))
978 }
979}
980impl From<u8> for Value {
981 fn from(i: u8) -> Self {
982 Value::Int(IntValue::U8(i))
983 }
984}
985
986impl From<f64> for Value {
987 fn from(f: f64) -> Self {
988 Value::Num(f)
989 }
990}
991
992impl From<bool> for Value {
993 fn from(b: bool) -> Self {
994 Value::Bool(b)
995 }
996}
997
998impl From<String> for Value {
999 fn from(s: String) -> Self {
1000 Value::String(s)
1001 }
1002}
1003
1004impl From<&str> for Value {
1005 fn from(s: &str) -> Self {
1006 Value::String(s.to_string())
1007 }
1008}
1009
1010impl From<Tensor> for Value {
1011 fn from(m: Tensor) -> Self {
1012 Value::Tensor(m)
1013 }
1014}
1015
1016impl TryFrom<&Value> for i32 {
1020 type Error = String;
1021 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1022 match v {
1023 Value::Int(i) => Ok(i.to_i64() as i32),
1024 Value::Num(n) => Ok(*n as i32),
1025 _ => Err(format!("cannot convert {v:?} to i32")),
1026 }
1027 }
1028}
1029
1030impl TryFrom<&Value> for f64 {
1031 type Error = String;
1032 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1033 match v {
1034 Value::Num(n) => Ok(*n),
1035 Value::Int(i) => Ok(i.to_f64()),
1036 _ => Err(format!("cannot convert {v:?} to f64")),
1037 }
1038 }
1039}
1040
1041impl TryFrom<&Value> for bool {
1042 type Error = String;
1043 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1044 match v {
1045 Value::Bool(b) => Ok(*b),
1046 Value::Int(i) => Ok(!i.is_zero()),
1047 Value::Num(n) => Ok(*n != 0.0),
1048 _ => Err(format!("cannot convert {v:?} to bool")),
1049 }
1050 }
1051}
1052
1053impl TryFrom<&Value> for String {
1054 type Error = String;
1055 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1056 match v {
1057 Value::String(s) => Ok(s.clone()),
1058 Value::StringArray(sa) => {
1059 if sa.data.len() == 1 {
1060 Ok(sa.data[0].clone())
1061 } else {
1062 Err("cannot convert string array to scalar string".to_string())
1063 }
1064 }
1065 Value::CharArray(ca) => {
1066 if ca.rows == 1 {
1068 Ok(ca.data.iter().collect())
1069 } else {
1070 Err("cannot convert multi-row char array to scalar string".to_string())
1071 }
1072 }
1073 Value::Int(i) => Ok(i.to_i64().to_string()),
1074 Value::Num(n) => Ok(n.to_string()),
1075 Value::Bool(b) => Ok(b.to_string()),
1076 _ => Err(format!("cannot convert {v:?} to String")),
1077 }
1078 }
1079}
1080
1081impl TryFrom<&Value> for Tensor {
1082 type Error = String;
1083 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1084 match v {
1085 Value::Tensor(m) => Ok(m.clone()),
1086 _ => Err(format!("cannot convert {v:?} to Tensor")),
1087 }
1088 }
1089}
1090
1091impl TryFrom<&Value> for Value {
1092 type Error = String;
1093 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1094 Ok(v.clone())
1095 }
1096}
1097
1098impl TryFrom<&Value> for Vec<Value> {
1099 type Error = String;
1100 fn try_from(v: &Value) -> Result<Self, Self::Error> {
1101 match v {
1102 Value::Cell(c) => Ok(c.data.iter().map(|p| (**p).clone()).collect()),
1103 _ => Err(format!("cannot convert {v:?} to Vec<Value>")),
1104 }
1105 }
1106}
1107
1108use serde::{Deserialize, Serialize};
1109
/// Static type used for inference over [`Value`]s (see `Type::from_value` and `Type::unify`).
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum Type {
    /// Integer scalar of any integer class.
    Int,
    /// Numeric scalar (`from_value` also maps complex scalars here).
    Num,
    /// Logical scalar.
    Bool,
    /// Logical array.
    Logical {
        /// Per-dimension sizes: `None` for the whole shape means nothing is known, and a
        /// `None` entry is an unknown dimension.
        shape: Option<Vec<Option<usize>>>,
    },
    /// Scalar string.
    String,
    /// Numeric array (dense, sparse, complex or GPU-resident per `from_value`).
    Tensor {
        /// Same encoding as `Logical::shape`.
        shape: Option<Vec<Option<usize>>>,
    },
    /// Symbolic expression.
    Symbolic,
    /// Cell array.
    Cell {
        /// Element type when known; `from_value` derives it from the first element only.
        element_type: Option<Box<Type>>,
        /// Number of elements when known.
        length: Option<usize>,
    },
    /// Callable value.
    Function {
        /// Parameter types.
        params: Vec<Type>,
        /// Return type.
        returns: Box<Type>,
    },
    /// Presumably "no value" (e.g. a call with no outputs) — TODO confirm.
    Void,
    /// Type not known statically; compatible with every other type.
    Unknown,
    /// One of several types; `unify` produces this for otherwise-incompatible inputs.
    Union(Vec<Type>),
    /// Struct value.
    Struct {
        /// Field names known to exist, when tracked.
        known_fields: Option<Vec<String>>, },
    /// Types of a list of values produced together.
    OutputList(Vec<Type>),
}
1162
1163impl Type {
1164 pub fn tensor() -> Self {
1166 Type::Tensor { shape: None }
1167 }
1168
1169 pub fn logical() -> Self {
1171 Type::Logical { shape: None }
1172 }
1173
1174 pub fn logical_with_shape(shape: Vec<usize>) -> Self {
1176 Type::Logical {
1177 shape: Some(shape.into_iter().map(Some).collect()),
1178 }
1179 }
1180
1181 pub fn tensor_with_shape(shape: Vec<usize>) -> Self {
1183 Type::Tensor {
1184 shape: Some(shape.into_iter().map(Some).collect()),
1185 }
1186 }
1187
1188 pub fn cell() -> Self {
1190 Type::Cell {
1191 element_type: None,
1192 length: None,
1193 }
1194 }
1195
1196 pub fn cell_of(element_type: Type) -> Self {
1198 Type::Cell {
1199 element_type: Some(Box::new(element_type)),
1200 length: None,
1201 }
1202 }
1203
1204 pub fn is_compatible_with(&self, other: &Type) -> bool {
1206 match (self, other) {
1207 (Type::Unknown, _) | (_, Type::Unknown) => true,
1208 (Type::Int, Type::Num) | (Type::Num, Type::Int) => true, (Type::Tensor { .. }, Type::Tensor { .. }) => true, (Type::OutputList(a), Type::OutputList(b)) => a.len() == b.len(),
1211 (a, b) => a == b,
1212 }
1213 }
1214
1215 pub fn unify(&self, other: &Type) -> Type {
1217 match (self, other) {
1218 (Type::Unknown, t) | (t, Type::Unknown) => t.clone(),
1219 (Type::Int, Type::Num) | (Type::Num, Type::Int) => Type::Num,
1220 (Type::Tensor { shape: a }, Type::Tensor { shape: b }) => {
1221 let a_norm = match a {
1222 Some(dims) if dims.is_empty() => None,
1223 _ => a.clone(),
1224 };
1225 let b_norm = match b {
1226 Some(dims) if dims.is_empty() => None,
1227 _ => b.clone(),
1228 };
1229 let a_unknown = a_norm
1230 .as_ref()
1231 .map(|dims| dims.iter().all(|d| d.is_none()))
1232 .unwrap_or(true);
1233 let b_unknown = b_norm
1234 .as_ref()
1235 .map(|dims| dims.iter().all(|d| d.is_none()))
1236 .unwrap_or(true);
1237 if a_norm == b_norm
1238 || (!a_unknown && b_unknown)
1239 || (a_norm.is_some() && b_norm.is_none())
1240 {
1241 Type::Tensor { shape: a_norm }
1242 } else if (a_unknown && !b_unknown) || (a_norm.is_none() && b_norm.is_some()) {
1243 Type::Tensor { shape: b_norm }
1244 } else {
1245 Type::tensor()
1246 }
1247 }
1248 (Type::Logical { shape: a }, Type::Logical { shape: b }) => {
1249 let a_norm = match a {
1250 Some(dims) if dims.is_empty() => None,
1251 _ => a.clone(),
1252 };
1253 let b_norm = match b {
1254 Some(dims) if dims.is_empty() => None,
1255 _ => b.clone(),
1256 };
1257 let a_unknown = a_norm
1258 .as_ref()
1259 .map(|dims| dims.iter().all(|d| d.is_none()))
1260 .unwrap_or(true);
1261 let b_unknown = b_norm
1262 .as_ref()
1263 .map(|dims| dims.iter().all(|d| d.is_none()))
1264 .unwrap_or(true);
1265 if a_norm == b_norm
1266 || (!a_unknown && b_unknown)
1267 || (a_norm.is_some() && b_norm.is_none())
1268 {
1269 Type::Logical { shape: a_norm }
1270 } else if (a_unknown && !b_unknown) || (a_norm.is_none() && b_norm.is_some()) {
1271 Type::Logical { shape: b_norm }
1272 } else {
1273 Type::logical()
1274 }
1275 }
1276 (Type::Struct { known_fields: a }, Type::Struct { known_fields: b }) => match (a, b) {
1277 (None, None) => Type::Struct { known_fields: None },
1278 (Some(ka), None) | (None, Some(ka)) => Type::Struct {
1279 known_fields: Some(ka.clone()),
1280 },
1281 (Some(ka), Some(kb)) => {
1282 let mut set: std::collections::BTreeSet<String> = ka.iter().cloned().collect();
1283 set.extend(kb.iter().cloned());
1284 Type::Struct {
1285 known_fields: Some(set.into_iter().collect()),
1286 }
1287 }
1288 },
1289 (Type::OutputList(a), Type::OutputList(b)) => {
1290 if a.len() == b.len() {
1291 let items = a
1292 .iter()
1293 .zip(b.iter())
1294 .map(|(lhs, rhs)| lhs.unify(rhs))
1295 .collect();
1296 Type::OutputList(items)
1297 } else {
1298 Type::OutputList(vec![Type::Unknown; a.len().max(b.len())])
1299 }
1300 }
1301 (a, b) if a == b => a.clone(),
1302 _ => Type::Union(vec![self.clone(), other.clone()]),
1303 }
1304 }
1305
1306 pub fn from_value(value: &Value) -> Type {
1308 match value {
1309 Value::Int(_) => Type::Int,
1310 Value::Num(_) => Type::Num,
1311 Value::Complex(_, _) => Type::Num, Value::Bool(_) => Type::Bool,
1313 Value::LogicalArray(arr) => Type::Logical {
1314 shape: Some(arr.shape.iter().map(|&d| Some(d)).collect()),
1315 },
1316 Value::String(_) => Type::String,
1317 Value::StringArray(_sa) => {
1318 Type::cell_of(Type::String)
1320 }
1321 Value::Tensor(t) => Type::Tensor {
1322 shape: Some(t.shape.iter().map(|&d| Some(d)).collect()),
1323 },
1324 Value::SparseTensor(t) => Type::Tensor {
1325 shape: Some(vec![Some(t.rows), Some(t.cols)]),
1326 },
1327 Value::ComplexTensor(t) => Type::Tensor {
1328 shape: Some(t.shape.iter().map(|&d| Some(d)).collect()),
1329 },
1330 Value::Symbolic(_) => Type::Symbolic,
1331 Value::Cell(cells) => {
1332 if cells.data.is_empty() {
1333 Type::cell()
1334 } else {
1335 let element_type = Type::from_value(&cells.data[0]);
1337 Type::Cell {
1338 element_type: Some(Box::new(element_type)),
1339 length: Some(cells.data.len()),
1340 }
1341 }
1342 }
1343 Value::GpuTensor(h) => Type::Tensor {
1344 shape: Some(h.shape.iter().map(|&d| Some(d)).collect()),
1345 },
1346 Value::Object(_) => Type::Unknown,
1347 Value::HandleObject(_) => Type::Unknown,
1348 Value::Listener(_) => Type::Unknown,
1349 Value::Struct(_) => Type::Struct { known_fields: None },
1350 Value::FunctionHandle(_)
1351 | Value::ExternalFunctionHandle(_)
1352 | Value::MethodFunctionHandle(_)
1353 | Value::BoundFunctionHandle { .. } => Type::Function {
1354 params: vec![Type::Unknown],
1355 returns: Box::new(Type::Unknown),
1356 },
1357 Value::Closure(_) => Type::Function {
1358 params: vec![Type::Unknown],
1359 returns: Box::new(Type::Unknown),
1360 },
1361 Value::ClassRef(_) => Type::Unknown,
1362 Value::MException(_) => Type::Unknown,
1363 Value::CharArray(ca) => {
1364 Type::Cell {
1366 element_type: Some(Box::new(Type::String)),
1367 length: Some(ca.rows * ca.cols),
1368 }
1369 }
1370 Value::OutputList(values) => {
1371 Type::OutputList(values.iter().map(Type::from_value).collect())
1372 }
1373 }
1374 }
1375}
1376
/// A closure value: the name of the function it calls plus the values it captured.
#[derive(Debug, Clone, PartialEq)]
pub struct Closure {
    /// Name of the function the closure invokes.
    pub function_name: String,
    /// NOTE(review): optional index, presumably identifying a bound function — confirm.
    pub bound_function: Option<usize>,
    /// Values captured when the closure was created.
    pub captures: Vec<Value>,
}

/// Coarse operation categories a builtin can be tagged with via
/// `BuiltinFunction::accel_tags`. NOTE(review): presumably consumed by an
/// acceleration/fusion planner — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelTag {
    Unary,
    Elementwise,
    Reduction,
    MatMul,
    Transpose,
    ArrayConstruct,
}

/// Error type a builtin's future resolves to when it does not produce a value.
pub type BuiltinControlFlow = runmat_async::RuntimeError;

/// Boxed, pinned future returned by every builtin implementation.
///
/// Resolves to the builtin's `Value` or a [`BuiltinControlFlow`] error. The trait
/// object carries no `Send` bound, so implementations need not be `Send`.
pub type BuiltinFuture = Pin<Box<dyn Future<Output = Result<Value, BuiltinControlFlow>> + 'static>>;
1400
/// Literal (compile-time-known) call arguments handed to context-aware type resolvers.
#[derive(Clone, Debug, Default)]
pub struct ResolveContext {
    /// One entry per call argument, in call order.
    pub literal_args: Vec<LiteralValue>,
}

/// A call argument as far as it is statically known.
#[derive(Clone, Debug, PartialEq)]
pub enum LiteralValue {
    Number(f64),
    Bool(bool),
    String(String),
    Vector(Vec<LiteralValue>),
    /// The argument is not a literal (or its value is not known statically).
    Unknown,
}

impl ResolveContext {
    /// Wrap the literal view of a call's arguments.
    pub fn new(literal_args: Vec<LiteralValue>) -> Self {
        Self { literal_args }
    }

    /// Dimension sizes read from all arguments; see [`Self::numeric_dims_from`].
    pub fn numeric_dims(&self) -> Vec<Option<usize>> {
        self.numeric_dims_from(0)
    }

    /// Dimension sizes read from the arguments starting at `start`.
    ///
    /// If the first considered argument is a vector literal (e.g. `zeros([2 3])`), its
    /// elements are the dimensions and any trailing arguments are ignored. Otherwise
    /// every remaining argument contributes one entry. Entries that are not
    /// non-negative integers representable as `usize` are `None`. An out-of-range
    /// `start` yields an empty vector.
    pub fn numeric_dims_from(&self, start: usize) -> Vec<Option<usize>> {
        let args = self.literal_args.get(start..).unwrap_or(&[]);
        let dims_source = match args.first() {
            Some(LiteralValue::Vector(values)) => values.as_slice(),
            _ => args,
        };
        dims_source
            .iter()
            .map(Self::numeric_dimension_from_literal)
            .collect()
    }

    /// ASCII-lowercased copy of the string literal at `index`, if that argument is one.
    pub fn literal_string_at(&self, index: usize) -> Option<String> {
        if let Some(LiteralValue::String(value)) = self.literal_args.get(index) {
            Some(value.to_ascii_lowercase())
        } else {
            None
        }
    }

    /// The boolean literal at `index`, if that argument is one.
    pub fn literal_bool_at(&self, index: usize) -> Option<bool> {
        if let Some(LiteralValue::Bool(value)) = self.literal_args.get(index) {
            Some(*value)
        } else {
            None
        }
    }

    /// A clone of the vector literal at `index`, if that argument is one.
    pub fn literal_vector_at(&self, index: usize) -> Option<Vec<LiteralValue>> {
        if let Some(LiteralValue::Vector(values)) = self.literal_args.get(index) {
            Some(values.clone())
        } else {
            None
        }
    }

    /// Dimension sizes from the flat vector literal at `index`.
    ///
    /// Returns `None` when the argument is not a vector or contains a nested vector.
    pub fn numeric_vector_at(&self, index: usize) -> Option<Vec<Option<usize>>> {
        let Some(LiteralValue::Vector(values)) = self.literal_args.get(index) else {
            return None;
        };
        if values.iter().any(|v| matches!(v, LiteralValue::Vector(_))) {
            return None;
        }
        Some(
            values
                .iter()
                .map(Self::numeric_dimension_from_literal)
                .collect(),
        )
    }

    /// Interpret a literal as a dimension size.
    ///
    /// Accepts finite numbers within `1e-9` of a non-negative integer that fits in
    /// `usize`; everything else (non-numbers, negatives, fractions, NaN/Inf, or values
    /// too large for `usize`) yields `None`.
    fn numeric_dimension_from_literal(value: &LiteralValue) -> Option<usize> {
        // 2^64: every non-negative integral f64 strictly below this converts to u64 exactly.
        const U64_EXCLUSIVE_LIMIT: f64 = 18_446_744_073_709_551_616.0;

        let LiteralValue::Number(num) = value else {
            return None;
        };
        if !num.is_finite() {
            return None;
        }
        let rounded = num.round();
        if (num - rounded).abs() > 1e-9 || rounded < 0.0 {
            return None;
        }
        // A bare `as usize` would saturate (1e300 -> usize::MAX, or clamp at u32::MAX on
        // 32-bit targets), inventing a bogus dimension; report "unknown" instead.
        if rounded >= U64_EXCLUSIVE_LIMIT {
            return None;
        }
        usize::try_from(rounded as u64).ok()
    }
}
1493
// Unit tests pinning how `ResolveContext` reads literal call arguments.
#[cfg(test)]
mod resolve_context_tests {
    use super::{LiteralValue, ResolveContext};

    // A leading vector literal supplies the dimensions directly.
    #[test]
    fn numeric_dims_reads_vector_literal() {
        let ctx = ResolveContext::new(vec![LiteralValue::Vector(vec![
            LiteralValue::Number(2.0),
            LiteralValue::Number(3.0),
        ])]);
        assert_eq!(ctx.numeric_dims(), vec![Some(2), Some(3)]);
    }

    // Non-numeric arguments keep their position but map to `None`.
    #[test]
    fn numeric_dims_skips_non_numeric_entries() {
        let ctx = ResolveContext::new(vec![
            LiteralValue::Number(4.0),
            LiteralValue::String("like".to_string()),
            LiteralValue::Unknown,
        ]);
        assert_eq!(ctx.numeric_dims(), vec![Some(4), None, None]);
    }

    // Arguments after a leading vector literal are ignored.
    #[test]
    fn numeric_dims_prefers_vector_even_with_trailing_args() {
        let ctx = ResolveContext::new(vec![
            LiteralValue::Vector(vec![LiteralValue::Number(1.0), LiteralValue::Number(5.0)]),
            LiteralValue::String("like".to_string()),
        ]);
        assert_eq!(ctx.numeric_dims(), vec![Some(1), Some(5)]);
    }

    // String literals are normalized to ASCII lowercase.
    #[test]
    fn literal_string_is_lowercased() {
        let ctx = ResolveContext::new(vec![LiteralValue::String("OmItNaN".to_string())]);
        assert_eq!(ctx.literal_string_at(0), Some("omitnan".to_string()));
    }

    #[test]
    fn literal_bool_is_available() {
        let ctx = ResolveContext::new(vec![LiteralValue::Bool(true)]);
        assert_eq!(ctx.literal_bool_at(0), Some(true));
    }

    // The vector is returned verbatim, including `Unknown` entries.
    #[test]
    fn literal_vector_at_returns_clone() {
        let ctx = ResolveContext::new(vec![LiteralValue::Vector(vec![
            LiteralValue::Number(7.0),
            LiteralValue::Unknown,
        ])]);
        assert_eq!(
            ctx.literal_vector_at(0),
            Some(vec![LiteralValue::Number(7.0), LiteralValue::Unknown])
        );
    }

    // A vector that contains another vector is not a valid size vector.
    #[test]
    fn numeric_vector_at_rejects_nested_vectors() {
        let ctx = ResolveContext::new(vec![LiteralValue::Vector(vec![LiteralValue::Vector(
            vec![LiteralValue::Number(1.0)],
        )])]);
        assert_eq!(ctx.numeric_vector_at(0), None);
    }
}
1558
1559pub type TypeResolver = fn(args: &[Type]) -> Type;
1560pub type TypeResolverWithContext = fn(args: &[Type], ctx: &ResolveContext) -> Type;
1561
1562#[derive(Clone, Copy, Debug)]
1563pub enum TypeResolverKind {
1564 Simple(TypeResolver),
1565 WithContext(TypeResolverWithContext),
1566}
1567
1568pub fn type_resolver_kind(resolver: TypeResolver) -> TypeResolverKind {
1569 TypeResolverKind::Simple(resolver)
1570}
1571
1572pub fn type_resolver_kind_ctx(resolver: TypeResolverWithContext) -> TypeResolverKind {
1573 TypeResolverKind::WithContext(resolver)
1574}
1575
/// How a builtin's number of outputs is determined.
/// NOTE(review): semantics inferred from variant names — confirm with the dispatcher.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinOutputMode {
    /// Presumably a fixed output count regardless of the call site.
    Fixed,
    /// Presumably varies with the number of outputs the caller requests.
    ByRequestedOutputCount,
}

/// Where a builtin should appear in completion/discovery surfaces.
/// NOTE(review): semantics inferred from variant names — confirm with consumers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinCompletionPolicy {
    Public,
    MethodOnly,
    HiddenInternal,
}

/// Whether a documented parameter is required, optional, or repeatable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinParamArity {
    Required,
    Optional,
    Variadic,
}

/// Coarse documentation-level classification of a parameter's expected kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum BuiltinParamType {
    Any,
    NumericScalar,
    IntegerScalar,
    StringScalar,
    NumericArray,
    LogicalArray,
    SizeArg,
    LikePrototype,
    AxesHandle,
    StyleSpec,
    PropertyName,
    PropertyValue,
}

/// Static description of one input or output parameter in a signature.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinParamDescriptor {
    pub name: &'static str,
    pub ty: BuiltinParamType,
    pub arity: BuiltinParamArity,
    /// Default value as display text, when there is one.
    pub default: Option<&'static str>,
    pub description: &'static str,
}

/// One calling form of a builtin: a display label plus its inputs and outputs.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinSignatureDescriptor {
    pub label: &'static str,
    pub inputs: &'static [BuiltinParamDescriptor],
    pub outputs: &'static [BuiltinParamDescriptor],
}

/// A documented error condition a builtin can raise.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinErrorDescriptor {
    pub code: &'static str,
    /// Error identifier (if any) reported alongside the message.
    pub identifier: Option<&'static str>,
    /// Description of the triggering condition.
    pub when: &'static str,
    pub message: &'static str,
}

/// Static, serializable metadata attached to a builtin via `BuiltinFunction::descriptor`.
#[derive(Debug, Clone, Serialize)]
pub struct BuiltinDescriptor {
    pub signatures: &'static [BuiltinSignatureDescriptor],
    pub output_mode: BuiltinOutputMode,
    pub completion_policy: BuiltinCompletionPolicy,
    pub errors: &'static [BuiltinErrorDescriptor],
}
1643
/// Registration record for one builtin function.
#[derive(Debug, Clone)]
pub struct BuiltinFunction {
    /// Name used for lookup; `builtin_function_by_name` matches it ASCII case-insensitively.
    pub name: &'static str,
    pub description: &'static str,
    pub category: &'static str,
    pub doc: &'static str,
    pub examples: &'static str,
    /// Declared parameter types.
    pub param_types: Vec<Type>,
    /// Fallback return type used when no `type_resolver` is set.
    pub return_type: Type,
    /// Optional resolver that computes the return type from the argument types.
    pub type_resolver: Option<TypeResolverKind>,
    /// Entry point; returns a future that yields the result value.
    pub implementation: fn(&[Value]) -> BuiltinFuture,
    pub accel_tags: &'static [AccelTag],
    /// NOTE(review): flag name suggests the builtin is a side-effect sink — confirm.
    pub is_sink: bool,
    /// Read by `suppresses_auto_output`.
    pub suppress_auto_output: bool,
    /// Optional static signature/error metadata.
    pub descriptor: Option<&'static BuiltinDescriptor>,
}
1661
1662impl BuiltinFunction {
1663 #[allow(clippy::too_many_arguments)]
1664 pub fn new(
1665 name: &'static str,
1666 description: &'static str,
1667 category: &'static str,
1668 doc: &'static str,
1669 examples: &'static str,
1670 param_types: Vec<Type>,
1671 return_type: Type,
1672 type_resolver: Option<TypeResolverKind>,
1673 implementation: fn(&[Value]) -> BuiltinFuture,
1674 accel_tags: &'static [AccelTag],
1675 is_sink: bool,
1676 suppress_auto_output: bool,
1677 ) -> Self {
1678 Self {
1679 name,
1680 description,
1681 category,
1682 doc,
1683 examples,
1684 param_types,
1685 return_type,
1686 type_resolver,
1687 implementation,
1688 accel_tags,
1689 is_sink,
1690 suppress_auto_output,
1691 descriptor: None,
1692 }
1693 }
1694
1695 pub fn with_descriptor(mut self, descriptor: &'static BuiltinDescriptor) -> Self {
1696 self.descriptor = Some(descriptor);
1697 self
1698 }
1699
1700 pub fn with_descriptor_option(
1701 mut self,
1702 descriptor: Option<&'static BuiltinDescriptor>,
1703 ) -> Self {
1704 self.descriptor = descriptor;
1705 self
1706 }
1707
1708 pub fn infer_return_type(&self, args: &[Type]) -> Type {
1709 self.infer_return_type_with_context(args, &ResolveContext::default())
1710 }
1711
1712 pub fn infer_return_type_with_context(&self, args: &[Type], ctx: &ResolveContext) -> Type {
1713 if let Some(resolver) = self.type_resolver {
1714 return match resolver {
1715 TypeResolverKind::Simple(resolver) => resolver(args),
1716 TypeResolverKind::WithContext(resolver) => resolver(args, ctx),
1717 };
1718 }
1719 self.return_type.clone()
1720 }
1721
1722 pub fn semantics(&self) -> BuiltinSemantics {
1723 semantics::builtin_semantics_for(self)
1724 }
1725}
1726
/// A named constant value registered with the runtime (e.g. via `inventory` on
/// native targets). `Debug` is implemented by hand elsewhere in this file.
#[derive(Clone)]
pub struct Constant {
    pub name: &'static str,
    pub value: Value,
}
1733
1734pub mod semantics;
1735pub mod shape_rules;
1736
1737pub use semantics::{
1738 builtin_semantics_for, builtin_semantics_for_name, BuiltinAsyncBehavior, BuiltinCompatibility,
1739 BuiltinEffects, BuiltinEnvironmentEffect, BuiltinPurity, BuiltinSemanticKind, BuiltinSemantics,
1740 BuiltinWorkspaceEffect, ConcatKind, ShapeTransformKind,
1741};
1742
1743impl std::fmt::Debug for Constant {
1744 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1745 write!(
1746 f,
1747 "Constant {{ name: {:?}, value: {:?} }}",
1748 self.name, self.value
1749 )
1750 }
1751}
1752
// Native targets gather registrations at link time through `inventory`.
#[cfg(not(target_arch = "wasm32"))]
inventory::collect!(BuiltinFunction);
#[cfg(not(target_arch = "wasm32"))]
inventory::collect!(Constant);

/// Every registered builtin (native: snapshot of the `inventory` collection).
#[cfg(not(target_arch = "wasm32"))]
pub fn builtin_functions() -> Vec<&'static BuiltinFunction> {
    inventory::iter::<BuiltinFunction>().collect()
}

/// Every registered builtin (wasm32: snapshot of `wasm_registry`'s function list).
#[cfg(target_arch = "wasm32")]
pub fn builtin_functions() -> Vec<&'static BuiltinFunction> {
    wasm_registry::builtin_functions()
}
1767
1768#[cfg(not(target_arch = "wasm32"))]
1769static BUILTIN_LOOKUP: OnceLock<HashMap<String, &'static BuiltinFunction>> = OnceLock::new();
1770
1771#[cfg(not(target_arch = "wasm32"))]
1772fn builtin_lookup_map() -> &'static HashMap<String, &'static BuiltinFunction> {
1773 BUILTIN_LOOKUP.get_or_init(|| {
1774 let mut map = HashMap::new();
1775 for func in builtin_functions() {
1776 map.insert(func.name.to_ascii_lowercase(), func);
1777 }
1778 map
1779 })
1780}
1781
1782#[cfg(not(target_arch = "wasm32"))]
1783pub fn builtin_function_by_name(name: &str) -> Option<&'static BuiltinFunction> {
1784 builtin_lookup_map()
1785 .get(&name.to_ascii_lowercase())
1786 .copied()
1787}
1788
1789#[cfg(target_arch = "wasm32")]
1790pub fn builtin_function_by_name(name: &str) -> Option<&'static BuiltinFunction> {
1791 wasm_registry::builtin_functions()
1792 .into_iter()
1793 .find(|f| f.name.eq_ignore_ascii_case(name))
1794}
1795
1796pub fn suppresses_auto_output(name: &str) -> bool {
1797 builtin_function_by_name(name)
1798 .map(|f| f.suppress_auto_output)
1799 .unwrap_or(false)
1800}
1801
/// Every registered constant (native: snapshot of the `inventory` collection).
#[cfg(not(target_arch = "wasm32"))]
pub fn constants() -> Vec<&'static Constant> {
    inventory::iter::<Constant>().collect()
}

/// Every registered constant (wasm32: snapshot of `wasm_registry`'s constant list).
#[cfg(target_arch = "wasm32")]
pub fn constants() -> Vec<&'static Constant> {
    wasm_registry::constants()
}

/// Free-form documentation metadata for a builtin, keyed by `name`.
/// Every field except `name` is optional text.
#[derive(Debug)]
pub struct BuiltinDoc {
    pub name: &'static str,
    pub category: Option<&'static str>,
    pub summary: Option<&'static str>,
    pub keywords: Option<&'static str>,
    pub errors: Option<&'static str>,
    pub related: Option<&'static str>,
    pub introduced: Option<&'static str>,
    pub status: Option<&'static str>,
    pub examples: Option<&'static str>,
}

#[cfg(not(target_arch = "wasm32"))]
inventory::collect!(BuiltinDoc);

/// Every registered doc entry (native: snapshot of the `inventory` collection).
#[cfg(not(target_arch = "wasm32"))]
pub fn builtin_docs() -> Vec<&'static BuiltinDoc> {
    inventory::iter::<BuiltinDoc>().collect()
}

/// Every registered doc entry (wasm32: snapshot of `wasm_registry`'s doc list).
#[cfg(target_arch = "wasm32")]
pub fn builtin_docs() -> Vec<&'static BuiltinDoc> {
    wasm_registry::builtin_docs()
}
1841
/// Numeric display mode consulted by `format_number`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum FormatMode {
    /// Integers as-is, otherwise 4 decimals or 4-digit scientific (the default).
    #[default]
    Short,
    /// Integers as-is, otherwise 15 decimals or 14-digit scientific.
    Long,
    /// Always scientific with 4 decimals.
    ShortE,
    /// Always scientific with 14 decimals.
    LongE,
    /// Compact form with 5 significant digits.
    ShortG,
    /// Compact form with 15 significant digits.
    LongG,
    /// Continued-fraction approximation `p/q`.
    Rational,
    /// Raw IEEE-754 bit pattern as 16 hex digits.
    Hex,
}
1867
1868runmat_thread_local! {
1869 static DISPLAY_FORMAT: RefCell<FormatMode> = const { RefCell::new(FormatMode::Short) };
1870}
1871
1872pub fn set_display_format(mode: FormatMode) {
1873 DISPLAY_FORMAT.with(|c| *c.borrow_mut() = mode);
1874}
1875
1876pub fn get_display_format() -> FormatMode {
1877 DISPLAY_FORMAT.with(|c| *c.borrow())
1878}
1879
1880pub fn format_number(value: f64) -> String {
1882 if value.is_nan() {
1883 return "NaN".to_string();
1884 }
1885 if value.is_infinite() {
1886 return if value.is_sign_negative() {
1887 "-Inf"
1888 } else {
1889 "Inf"
1890 }
1891 .to_string();
1892 }
1893 let mode = get_display_format();
1894 if mode == FormatMode::Hex {
1895 return fmt_hex(value);
1896 }
1897 let v = if value == 0.0 { 0.0 } else { value };
1898 match mode {
1899 FormatMode::Short => fmt_short(v),
1900 FormatMode::Long => fmt_long(v),
1901 FormatMode::ShortE => fmt_sci(v, 4),
1902 FormatMode::LongE => fmt_sci(v, 14),
1903 FormatMode::ShortG => fmt_compact(v, 5),
1904 FormatMode::LongG => fmt_compact(v, 15),
1905 FormatMode::Rational => fmt_rational(v),
1906 FormatMode::Hex => unreachable!("hex mode handled before zero normalization"),
1907 }
1908}
1909
/// Rewrite a Rust `{:e}` string into MATLAB style: the exponent always carries an
/// explicit sign and at least two digits (`1.5e3` -> `1.5e+03`). Strings without an
/// `e` are returned unchanged; an unparsable exponent is treated as `0`.
fn matlab_exp(s: &str) -> String {
    let Some((mantissa, exponent)) = s.split_once('e') else {
        return s.to_string();
    };
    let exponent: i32 = exponent.parse().unwrap_or(0);
    let sign = if exponent < 0 { '-' } else { '+' };
    format!("{mantissa}e{sign}{:02}", exponent.unsigned_abs())
}

/// Scientific notation with `dec` mantissa decimals, MATLAB-style exponent.
/// Zero prints as `0.` followed by `max(dec, 1)` zeros and `e+00`.
fn fmt_sci(v: f64, dec: usize) -> String {
    if v == 0.0 {
        // At least one fractional zero, even when `dec` is 0.
        return format!("0.{}e+00", "0".repeat(dec.max(1)));
    }
    matlab_exp(&format!("{v:.dec$e}"))
}

/// `format short`: integers below 1e15 exactly, otherwise 4 decimals inside
/// [0.001, 10000) and 4-decimal scientific notation outside it.
fn fmt_short(v: f64) -> String {
    let magnitude = v.abs();
    match magnitude {
        m if m == 0.0 => "0".to_string(),
        m if v.fract() == 0.0 && m < 1e15 => (v as i64).to_string(),
        m if (0.001..10000.0).contains(&m) => format!("{v:.4}"),
        _ => fmt_sci(v, 4),
    }
}

/// `format long`: integers below 1e15 exactly, otherwise 15 decimals inside
/// [0.001, 10000) and 14-decimal scientific notation outside it.
fn fmt_long(v: f64) -> String {
    let magnitude = v.abs();
    match magnitude {
        m if m == 0.0 => "0".to_string(),
        m if v.fract() == 0.0 && m < 1e15 => (v as i64).to_string(),
        m if (0.001..10000.0).contains(&m) => format!("{v:.15}"),
        _ => fmt_sci(v, 14),
    }
}

/// `format shortG`/`longG`-style output with `sig_digits` significant digits.
///
/// Magnitudes outside [1e-4, 1e6) use scientific notation. Everywhere, trailing
/// fractional zeros and a dangling `.` are removed.
fn fmt_compact(v: f64, sig_digits: usize) -> String {
    /// Drop trailing zeros after the decimal point, then a bare trailing `.`;
    /// text without a `.` is left alone.
    fn strip_fraction_zeros(text: &str) -> &str {
        if !text.contains('.') {
            return text;
        }
        let trimmed = text.trim_end_matches('0');
        trimmed.strip_suffix('.').unwrap_or(trimmed)
    }

    let magnitude = v.abs();
    if magnitude == 0.0 {
        return "0".to_string();
    }

    if !(1e-4..1e6).contains(&magnitude) {
        let precision = sig_digits - 1;
        let raw = format!("{v:.precision$e}");
        return match raw.split_once('e') {
            Some((mantissa, exponent)) => {
                let mantissa = strip_fraction_zeros(mantissa);
                matlab_exp(&format!("{mantissa}e{exponent}"))
            }
            None => matlab_exp(&raw),
        };
    }

    // Round to `sig_digits` significant digits, then print that many decimals.
    let exponent = magnitude.log10().floor() as i32;
    let decimals = (sig_digits as i32 - 1 - exponent).max(0) as usize;
    let scale = 10f64.powi(decimals as i32);
    let rounded = (v * scale).round() / scale;
    let formatted = format!("{rounded:.decimals$}");
    let text = strip_fraction_zeros(&formatted);
    if text.is_empty() || text == "-0" {
        "0".to_string()
    } else {
        text.to_string()
    }
}
2007
/// `format rat`: approximate `v` by a fraction `p/q` using continued fractions.
///
/// Integral values are printed exactly. Other values are expanded until the
/// approximation is within a relative tolerance of `5e-7`, the denominator would
/// exceed 1_000_000, the integer arithmetic would overflow, or 50 steps have run.
/// Whole results print without `/1`. NaN and infinities print as `NaN`, `Inf`,
/// and `-Inf`.
fn fmt_rational(v: f64) -> String {
    if v.is_nan() {
        return "NaN".to_string();
    }
    if v.is_infinite() {
        return if v < 0.0 { "-Inf" } else { "Inf" }.to_string();
    }
    if v == 0.0 {
        return "0".to_string();
    }
    // Integral values print exactly. Every finite f64 with |v| >= 2^53 is integral,
    // so this also keeps the `as i64` casts below from saturating. Previously 1e20
    // printed as i64::MAX.
    if v.fract() == 0.0 {
        return format!("{v:.0}");
    }
    let negative = v < 0.0;
    let abs = v.abs();
    let tol = 5e-7 * abs;
    let max_d = 1_000_000i64;
    // Convergent recurrence: n_k = q*n_{k-1} + n_{k-2}, d_k = q*d_{k-1} + d_{k-2}.
    let mut n0: i64 = 1;
    let mut n1: i64 = abs.floor() as i64;
    let mut d0: i64 = 0;
    let mut d1: i64 = 1;
    let mut a = abs;
    let mut best_n = n1;
    let mut best_d = d1;
    for _ in 0..50 {
        if (abs - best_n as f64 / best_d as f64).abs() <= tol {
            break;
        }
        let f = a.fract();
        if f < 1e-10 {
            break;
        }
        a = 1.0 / f;
        let q = a.floor() as i64;
        let Some(n2) = q.checked_mul(n1).and_then(|v| v.checked_add(n0)) else {
            break;
        };
        let Some(d2) = q.checked_mul(d1).and_then(|v| v.checked_add(d0)) else {
            break;
        };
        if d2 > max_d {
            break;
        }
        best_n = n2;
        best_d = d2;
        n0 = n1;
        n1 = n2;
        d0 = d1;
        d1 = d2;
    }
    let sign = if negative { "-" } else { "" };
    if best_d == 1 {
        format!("{sign}{best_n}")
    } else {
        format!("{sign}{best_n}/{best_d}")
    }
}
2061
/// `format hex`: the raw IEEE-754 bit pattern as 16 lowercase hex digits.
fn fmt_hex(v: f64) -> String {
    let bits: u64 = v.to_bits();
    format!("{bits:016x}")
}
2065
/// An exception value: an identifier, a message, and a list of stack entries.
#[derive(Debug, Clone, PartialEq)]
pub struct MException {
    pub identifier: String,
    pub message: String,
    pub stack: Vec<String>,
}

impl MException {
    /// Create an exception with the given identifier and message and an empty stack.
    pub fn new(identifier: String, message: String) -> Self {
        let stack = Vec::default();
        Self {
            stack,
            message,
            identifier,
        }
    }
}
2083
/// Reference to a handle-class object living behind a GC pointer.
///
/// Equality compares only the target pointer (see the `PartialEq` impl for
/// `HandleRef`); `is_handle_valid` combines `valid` with a flag stored on the target.
#[derive(Debug, Clone)]
pub struct HandleRef {
    pub class_name: String,
    pub target: GcPtr<Value>,
    /// Validity flag carried on this reference itself.
    pub valid: bool,
}
2091
2092const HANDLE_VALID_FLAG_PROPERTY: &str = "__runmat_handle_valid__";
2093
2094pub fn is_handle_valid(handle: &HandleRef) -> bool {
2095 if !handle.valid {
2096 return false;
2097 }
2098 let raw = unsafe { handle.target.as_raw() };
2099 if raw.is_null() {
2100 return false;
2101 }
2102 match unsafe { &*raw } {
2103 Value::Object(obj) => !matches!(
2104 obj.properties.get(HANDLE_VALID_FLAG_PROPERTY),
2105 Some(Value::Bool(false))
2106 ),
2107 _ => true,
2108 }
2109}
2110
2111pub fn set_handle_valid(handle: &HandleRef, valid: bool) -> bool {
2112 let raw = unsafe { handle.target.as_raw_mut() };
2113 if raw.is_null() {
2114 return false;
2115 }
2116 match unsafe { &mut *raw } {
2117 Value::Object(obj) => {
2118 obj.properties
2119 .insert(HANDLE_VALID_FLAG_PROPERTY.to_string(), Value::Bool(valid));
2120 true
2121 }
2122 _ => false,
2123 }
2124}
2125
2126impl PartialEq for HandleRef {
2127 fn eq(&self, other: &Self) -> bool {
2128 let a = unsafe { self.target.as_raw() } as usize;
2129 let b = unsafe { other.target.as_raw() } as usize;
2130 a == b
2131 }
2132}
2133
/// An event listener registered on a target object.
#[derive(Debug, Clone, PartialEq)]
pub struct Listener {
    pub id: u64,
    /// Object whose events are observed.
    pub target: GcPtr<Value>,
    pub event_name: String,
    /// Callback value. NOTE(review): presumably a function handle or closure — confirm.
    pub callback: GcPtr<Value>,
    pub enabled: bool,
    pub valid: bool,
}
2144
2145impl Listener {
2146 pub fn class_name(&self) -> String {
2147 match unsafe { &*self.target.as_raw() } {
2148 Value::Object(o) => o.class_name.clone(),
2149 Value::HandleObject(h) => h.class_name.clone(),
2150 _ => String::new(),
2151 }
2152 }
2153}
2154
2155impl fmt::Display for Value {
2156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2157 match self {
2158 Value::Int(i) => write!(f, "{}", i.to_i64()),
2159 Value::Num(n) => write!(f, "{}", format_number(*n)),
2160 Value::Complex(re, im) => {
2161 if *im == 0.0 {
2162 write!(f, "{}", format_number(*re))
2163 } else if *re == 0.0 {
2164 write!(f, "{}i", format_number(*im))
2165 } else if *im < 0.0 {
2166 write!(f, "{}-{}i", format_number(*re), format_number(im.abs()))
2167 } else {
2168 write!(f, "{}+{}i", format_number(*re), format_number(*im))
2169 }
2170 }
2171 Value::Bool(b) => write!(f, "{}", if *b { 1 } else { 0 }),
2172 Value::LogicalArray(la) => write!(f, "{la}"),
2173 Value::String(s) => write!(f, "'{s}'"),
2174 Value::StringArray(sa) => write!(f, "{sa}"),
2175 Value::CharArray(ca) => write!(f, "{ca}"),
2176 Value::Tensor(m) => write!(f, "{m}"),
2177 Value::SparseTensor(m) => write!(f, "{m}"),
2178 Value::ComplexTensor(m) => write!(f, "{m}"),
2179 Value::Symbolic(expr) => write!(f, "{expr}"),
2180 Value::Cell(ca) => ca.fmt(f),
2181
2182 Value::GpuTensor(h) => write!(
2183 f,
2184 "GpuTensor(shape={:?}, device={}, buffer={})",
2185 h.shape, h.device_id, h.buffer_id
2186 ),
2187 Value::Object(obj) => write!(f, "{}(props={})", obj.class_name, obj.properties.len()),
2188 Value::HandleObject(h) => {
2189 let ptr = unsafe { h.target.as_raw() } as usize;
2190 write!(
2191 f,
2192 "<handle {} @0x{:x} valid={}>",
2193 h.class_name, ptr, h.valid
2194 )
2195 }
2196 Value::Listener(l) => {
2197 let ptr = unsafe { l.target.as_raw() } as usize;
2198 write!(
2199 f,
2200 "<listener id={} {}@0x{:x} '{}' enabled={} valid={}>",
2201 l.id,
2202 l.class_name(),
2203 ptr,
2204 l.event_name,
2205 l.enabled,
2206 l.valid
2207 )
2208 }
2209 Value::Struct(st) => {
2210 write!(f, "struct {{")?;
2211 for (i, (key, val)) in st.fields.iter().enumerate() {
2212 if i > 0 {
2213 write!(f, ", ")?;
2214 }
2215 write!(f, "{}: {}", key, val)?;
2216 }
2217 write!(f, "}}")
2218 }
2219 Value::OutputList(values) => {
2220 write!(f, "[")?;
2221 for (i, value) in values.iter().enumerate() {
2222 if i > 0 {
2223 write!(f, ", ")?;
2224 }
2225 write!(f, "{}", value)?;
2226 }
2227 write!(f, "]")
2228 }
2229 Value::FunctionHandle(name)
2230 | Value::ExternalFunctionHandle(name)
2231 | Value::MethodFunctionHandle(name) => {
2232 write!(f, "@{name}")
2233 }
2234 Value::BoundFunctionHandle { name, .. } => write!(f, "@{name}"),
2235 Value::Closure(c) => write!(
2236 f,
2237 "<closure {} captures={}>",
2238 c.function_name,
2239 c.captures.len()
2240 ),
2241 Value::ClassRef(name) => write!(f, "<class {name}>"),
2242 Value::MException(e) => write!(
2243 f,
2244 "MException(identifier='{}', message='{}')",
2245 e.identifier, e.message
2246 ),
2247 }
2248 }
2249}
2250
2251impl fmt::Display for ComplexTensor {
2252 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2253 match self.shape.len() {
2254 0 | 1 => {
2255 write!(f, "[")?;
2256 for (i, (re, im)) in self.data.iter().enumerate() {
2257 if i > 0 {
2258 write!(f, " ")?;
2259 }
2260 let s = Value::Complex(*re, *im).to_string();
2261 write!(f, "{s}")?;
2262 }
2263 write!(f, "]")
2264 }
2265 2 => {
2266 let rows = self.rows;
2267 let cols = self.cols;
2268 write!(f, "[")?;
2269 for r in 0..rows {
2270 for c in 0..cols {
2271 if c > 0 {
2272 write!(f, " ")?;
2273 }
2274 let (re, im) = self.data[r + c * rows];
2275 let s = Value::Complex(re, im).to_string();
2276 write!(f, "{s}")?;
2277 }
2278 if r + 1 < rows {
2279 write!(f, "; ")?;
2280 }
2281 }
2282 write!(f, "]")
2283 }
2284 _ => {
2285 if should_expand_nd_display(&self.shape) {
2286 write_nd_pages(f, &self.shape, |f, idx| {
2287 let (re, im) = self.data[idx];
2288 write!(f, "{}", Value::Complex(re, im))
2289 })
2290 } else {
2291 write!(f, "ComplexTensor(shape={:?})", self.shape)
2292 }
2293 }
2294 }
2295 }
2296}
2297
// Unit tests for numeric formatting and N-D display of tensors and logical arrays.
#[cfg(test)]
mod display_tests {
    use super::{
        fmt_rational, format_number, set_display_format, ComplexTensor, FormatMode, LogicalArray,
        Tensor,
    };

    // Regression: huge values whose fractional part is lost must not overflow the
    // continued-fraction arithmetic.
    #[test]
    fn fmt_rational_large_value_with_tiny_fract_does_not_overflow() {
        let result = std::panic::catch_unwind(|| fmt_rational(1_000_000_000_000_000.000_1));
        assert!(
            result.is_ok(),
            "fmt_rational panicked on large value with tiny fract"
        );

        let result = std::panic::catch_unwind(|| fmt_rational(-1_000_000_000_000_000.000_1));
        assert!(
            result.is_ok(),
            "fmt_rational panicked on negative large value with tiny fract"
        );
    }

    // 3-D tensors print one `(:, :, k) =` page per trailing index.
    #[test]
    fn tensor_nd_display_uses_page_headers() {
        let tensor = Tensor::new(
            vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
            vec![2, 3, 2],
        )
        .expect("tensor");
        let rendered = tensor.to_string();
        assert!(rendered.contains("(:, :, 1) ="));
        assert!(rendered.contains("(:, :, 2) ="));
        assert!(rendered.contains(" 1 0 0"));
    }

    // Above the expansion limit only a shape summary is printed.
    #[test]
    fn tensor_nd_display_falls_back_for_large_arrays() {
        let tensor = Tensor::new(vec![0.0; 4097], vec![1, 1, 4097]).expect("tensor");
        assert_eq!(tensor.to_string(), "Tensor(shape=[1, 1, 4097])");
    }

    #[test]
    fn logical_nd_display_uses_headers_and_fallback_summary() {
        let logical =
            LogicalArray::new(vec![1, 0, 0, 1, 1, 0, 0, 1], vec![2, 2, 2]).expect("logical");
        let rendered = logical.to_string();
        assert!(rendered.contains("(:, :, 1) ="));
        assert!(rendered.contains("(:, :, 2) ="));

        let large = LogicalArray::new(vec![1; 4097], vec![1, 1, 4097]).expect("large logical");
        assert_eq!(large.to_string(), "1x1x4097 logical array");
    }

    #[test]
    fn complex_nd_display_uses_page_headers() {
        let complex = ComplexTensor::new(
            vec![(1.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0)],
            vec![2, 1, 2],
        )
        .expect("complex");
        let rendered = complex.to_string();
        assert!(rendered.contains("(:, :, 1) ="));
        assert!(rendered.contains("(:, :, 2) ="));
    }

    // Hex mode must see the raw bits, so -0.0 keeps its sign bit. The mode is reset
    // to Short afterwards.
    #[test]
    fn format_hex_preserves_negative_zero_sign_bit() {
        set_display_format(FormatMode::Hex);
        assert_eq!(format_number(-0.0), "8000000000000000");
        assert_eq!(format_number(0.0), "0000000000000000");
        set_display_format(FormatMode::Short);
    }
}
2374
/// A cell array: GC-managed value handles plus their shape.
#[derive(Debug, Clone, PartialEq)]
pub struct CellArray {
    /// One handle per element; its length equals the product of `shape`.
    pub data: Vec<GcPtr<Value>>,
    /// Full shape. An empty shape means zero elements.
    pub shape: Vec<usize>,
    /// First dimension (1 for a 1-D shape, 0 for an empty shape).
    pub rows: usize,
    /// Second dimension (the only dimension for a 1-D shape, 0 for an empty shape).
    pub cols: usize,
}
2385
2386impl CellArray {
2387 pub fn new_handles(
2388 handles: Vec<GcPtr<Value>>,
2389 rows: usize,
2390 cols: usize,
2391 ) -> Result<Self, String> {
2392 Self::new_handles_with_shape(handles, vec![rows, cols])
2393 }
2394
2395 pub fn new_handles_with_shape(
2396 handles: Vec<GcPtr<Value>>,
2397 shape: Vec<usize>,
2398 ) -> Result<Self, String> {
2399 let expected = total_len(&shape)
2400 .ok_or_else(|| "Cell data shape exceeds platform limits".to_string())?;
2401 if expected != handles.len() {
2402 return Err(format!(
2403 "Cell data length {} doesn't match shape {:?} ({} elements)",
2404 handles.len(),
2405 shape,
2406 expected
2407 ));
2408 }
2409 let (rows, cols) = shape_rows_cols(&shape);
2410 Ok(CellArray {
2411 data: handles,
2412 shape,
2413 rows,
2414 cols,
2415 })
2416 }
2417
2418 pub fn new(data: Vec<Value>, rows: usize, cols: usize) -> Result<Self, String> {
2419 Self::new_with_shape(data, vec![rows, cols])
2420 }
2421
2422 pub fn new_with_shape(data: Vec<Value>, shape: Vec<usize>) -> Result<Self, String> {
2423 let expected = total_len(&shape)
2424 .ok_or_else(|| "Cell data shape exceeds platform limits".to_string())?;
2425 if expected != data.len() {
2426 return Err(format!(
2427 "Cell data length {} doesn't match shape {:?} ({} elements)",
2428 data.len(),
2429 shape,
2430 expected
2431 ));
2432 }
2433 let handles: Vec<GcPtr<Value>> = data
2435 .into_iter()
2436 .map(|v| unsafe { GcPtr::from_raw(Box::into_raw(Box::new(v))) })
2437 .collect();
2438 Self::new_handles_with_shape(handles, shape)
2439 }
2440
2441 pub fn get(&self, row: usize, col: usize) -> Result<Value, String> {
2442 if row >= self.rows || col >= self.cols {
2443 return Err(format!(
2444 "Cell index ({row}, {col}) out of bounds for {}x{} cell array",
2445 self.rows, self.cols
2446 ));
2447 }
2448 Ok((*self.data[row * self.cols + col]).clone())
2449 }
2450}
2451
/// Number of elements described by `shape`, or `None` if the product overflows `usize`.
/// An empty shape describes zero elements.
fn total_len(shape: &[usize]) -> Option<usize> {
    match shape {
        [] => Some(0),
        dims => dims.iter().try_fold(1usize, |acc, &dim| acc.checked_mul(dim)),
    }
}

/// `(rows, cols)` view of a shape.
///
/// An empty shape gives `(0, 0)`, a 1-D shape `[n]` is a row `(1, n)`, and longer
/// shapes use their first two dimensions.
fn shape_rows_cols(shape: &[usize]) -> (usize, usize) {
    match *shape {
        [] => (0, 0),
        [only] => (1, only),
        [rows, cols, ..] => (rows, cols),
    }
}
2470
2471impl fmt::Display for CellArray {
2472 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2473 let dims: Vec<String> = self.shape.iter().map(|d| d.to_string()).collect();
2474 if self.shape.len() > 2 {
2475 return write!(f, "{} cell array", dims.join("x"));
2476 }
2477 write!(f, "{}x{} cell array", self.rows, self.cols)?;
2478 if self.rows == 0 || self.cols == 0 {
2479 return Ok(());
2480 }
2481 for r in 0..self.rows {
2482 writeln!(f)?;
2483 write!(f, " ")?;
2484 for c in 0..self.cols {
2485 if c > 0 {
2486 write!(f, " ")?;
2487 }
2488 let value = self.get(r, c).unwrap_or_else(|_| Value::Num(f64::NAN));
2489 write!(f, "{{{value}}}")?;
2490 }
2491 }
2492 Ok(())
2493 }
2494}
2495
2496#[derive(Debug, Clone, PartialEq)]
2497pub struct ObjectInstance {
2498 pub class_name: String,
2499 pub properties: HashMap<String, Value>,
2500}
2501
2502impl ObjectInstance {
2503 pub fn new(class_name: String) -> Self {
2504 Self {
2505 class_name,
2506 properties: HashMap::new(),
2507 }
2508 }
2509
2510 pub fn is_class(&self, name: &str) -> bool {
2511 self.class_name == name
2512 }
2513}
2514
/// Access level of a class member.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Access {
    Public,
    Private,
    Protected,
}

/// Declaration of one class property and its attributes.
#[derive(Debug, Clone)]
pub struct PropertyDef {
    pub name: String,
    pub is_static: bool,
    pub is_constant: bool,
    pub is_dependent: bool,
    /// Who may read the property.
    pub get_access: Access,
    /// Who may assign the property.
    pub set_access: Access,
    /// Initial value, if one was declared.
    pub default_value: Option<Value>,
}

/// Declaration of one class method and its attributes.
#[derive(Debug, Clone)]
pub struct MethodDef {
    pub name: String,
    pub is_static: bool,
    pub is_abstract: bool,
    pub is_sealed: bool,
    pub access: Access,
    /// Name of the function that implements the method.
    pub function_name: String,
    /// Class name supplied implicitly to the implementation; `primitive_class_registry`
    /// sets it so that `double.zeros` and friends carry their class name.
    pub implicit_class_argument: Option<String>,
}

/// A class definition as stored in the class registry.
#[derive(Debug, Clone)]
pub struct ClassDef {
    pub name: String,
    /// Superclass name, if any.
    pub parent: Option<String>,
    /// Properties keyed by name.
    pub properties: HashMap<String, PropertyDef>,
    /// Methods keyed by name.
    pub methods: HashMap<String, MethodDef>,
}
2552
2553use std::sync::Mutex;
2554
2555static CLASS_REGISTRY: OnceLock<Mutex<HashMap<String, ClassDef>>> = OnceLock::new();
2556static SEALED_CLASS_REGISTRY: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
2557static ABSTRACT_CLASS_REGISTRY: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
2558static STATIC_VALUES: OnceLock<Mutex<HashMap<(String, String), Value>>> = OnceLock::new();
2559static ENUMERATION_REGISTRY: OnceLock<Mutex<HashMap<String, HashSet<String>>>> = OnceLock::new();
2560
2561fn registry() -> &'static Mutex<HashMap<String, ClassDef>> {
2562 CLASS_REGISTRY.get_or_init(|| Mutex::new(primitive_class_registry()))
2563}
2564
2565fn sealed_registry() -> &'static Mutex<HashSet<String>> {
2566 SEALED_CLASS_REGISTRY.get_or_init(|| Mutex::new(HashSet::new()))
2567}
2568
2569fn abstract_registry() -> &'static Mutex<HashSet<String>> {
2570 ABSTRACT_CLASS_REGISTRY.get_or_init(|| Mutex::new(HashSet::new()))
2571}
2572
2573fn enumeration_registry() -> &'static Mutex<HashMap<String, HashSet<String>>> {
2574 ENUMERATION_REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
2575}
2576
2577fn primitive_class_registry() -> HashMap<String, ClassDef> {
2578 ["double", "single", "logical"]
2579 .into_iter()
2580 .map(|class_name| {
2581 let mut methods = HashMap::new();
2582 methods.insert(
2583 "zeros".to_string(),
2584 MethodDef {
2585 name: "zeros".to_string(),
2586 is_static: true,
2587 is_abstract: false,
2588 is_sealed: false,
2589 access: Access::Public,
2590 function_name: "zeros".to_string(),
2591 implicit_class_argument: Some(class_name.to_string()),
2592 },
2593 );
2594 (
2595 class_name.to_string(),
2596 ClassDef {
2597 name: class_name.to_string(),
2598 parent: None,
2599 properties: HashMap::new(),
2600 methods,
2601 },
2602 )
2603 })
2604 .collect()
2605}
2606
2607pub fn register_class(def: ClassDef) {
2608 register_class_with_modifiers(def, false, false);
2609}
2610
2611pub fn register_class_with_sealed(def: ClassDef, is_sealed: bool) {
2612 register_class_with_modifiers(def, is_sealed, false);
2613}
2614
2615pub fn register_class_with_modifiers(def: ClassDef, is_sealed: bool, is_abstract: bool) {
2616 let mut m = registry().lock().unwrap();
2617 let class_name = def.name.clone();
2618 m.insert(class_name.clone(), def);
2619 let mut sealed = sealed_registry().lock().unwrap();
2620 if is_sealed {
2621 sealed.insert(class_name.clone());
2622 } else {
2623 sealed.remove(&class_name);
2624 }
2625 let mut abstract_classes = abstract_registry().lock().unwrap();
2626 if is_abstract {
2627 abstract_classes.insert(class_name.clone());
2628 } else {
2629 abstract_classes.remove(&class_name);
2630 }
2631 enumeration_registry()
2632 .lock()
2633 .unwrap()
2634 .entry(class_name)
2635 .or_default();
2636}
2637
2638pub fn register_class_enumerations(class_name: &str, members: impl IntoIterator<Item = String>) {
2639 let mut registry = enumeration_registry().lock().unwrap();
2640 let entry = registry.entry(class_name.to_string()).or_default();
2641 entry.clear();
2642 entry.extend(members);
2643}
2644
2645pub fn class_has_enumeration_member(class_name: &str, member: &str) -> bool {
2646 enumeration_registry()
2647 .lock()
2648 .unwrap()
2649 .get(class_name)
2650 .is_some_and(|members| members.contains(member))
2651}
2652
2653pub fn get_class(name: &str) -> Option<ClassDef> {
2654 registry().lock().unwrap().get(name).cloned()
2655}
2656
2657pub fn class_names() -> Vec<String> {
2658 registry().lock().unwrap().keys().cloned().collect()
2659}
2660
2661pub fn is_class_sealed(name: &str) -> bool {
2662 sealed_registry().lock().unwrap().contains(name)
2663}
2664
2665pub fn is_class_abstract(name: &str) -> bool {
2666 abstract_registry().lock().unwrap().contains(name)
2667}
2668
2669pub fn is_class_or_subclass(class_name: &str, ancestor_name: &str) -> bool {
2670 if class_name == ancestor_name {
2671 return true;
2672 }
2673 let reg = registry().lock().unwrap();
2674 let mut current = Some(class_name.to_string());
2675 let mut visited = std::collections::HashSet::new();
2676 while let Some(name) = current {
2677 if !visited.insert(name.clone()) {
2678 break;
2679 }
2680 if name == ancestor_name {
2681 return true;
2682 }
2683 current = reg
2684 .get(&name)
2685 .and_then(|class_def| class_def.parent.clone());
2686 }
2687 false
2688}
2689
2690pub fn lookup_property(class_name: &str, prop: &str) -> Option<(PropertyDef, String)> {
2693 let reg = registry().lock().unwrap();
2694 let mut current = Some(class_name.to_string());
2695 let mut visited = std::collections::HashSet::new();
2696 while let Some(name) = current {
2697 if !visited.insert(name.clone()) {
2698 break;
2699 }
2700 if let Some(cls) = reg.get(&name) {
2701 if let Some(p) = cls.properties.get(prop) {
2702 return Some((p.clone(), name));
2703 }
2704 current = cls.parent.clone();
2705 } else {
2706 break;
2707 }
2708 }
2709 None
2710}
2711
2712pub fn lookup_method(class_name: &str, method: &str) -> Option<(MethodDef, String)> {
2715 let reg = registry().lock().unwrap();
2716 let mut current = Some(class_name.to_string());
2717 let mut visited = std::collections::HashSet::new();
2718 while let Some(name) = current {
2719 if !visited.insert(name.clone()) {
2720 break;
2721 }
2722 if let Some(cls) = reg.get(&name) {
2723 if let Some(m) = cls.methods.get(method) {
2724 return Some((m.clone(), name));
2725 }
2726 current = cls.parent.clone();
2727 } else {
2728 break;
2729 }
2730 }
2731 None
2732}
2733
2734fn static_values() -> &'static Mutex<HashMap<(String, String), Value>> {
2735 STATIC_VALUES.get_or_init(|| Mutex::new(HashMap::new()))
2736}
2737
2738pub fn get_static_property_value(class_name: &str, prop: &str) -> Option<Value> {
2739 static_values()
2740 .lock()
2741 .unwrap()
2742 .get(&(class_name.to_string(), prop.to_string()))
2743 .cloned()
2744}
2745
2746pub fn set_static_property_value(class_name: &str, prop: &str, value: Value) {
2747 static_values()
2748 .lock()
2749 .unwrap()
2750 .insert((class_name.to_string(), prop.to_string()), value);
2751}
2752
2753pub fn set_static_property_value_in_owner(
2755 class_name: &str,
2756 prop: &str,
2757 value: Value,
2758) -> Result<(), String> {
2759 if let Some((_p, owner)) = lookup_property(class_name, prop) {
2760 set_static_property_value(&owner, prop, value);
2761 Ok(())
2762 } else {
2763 Err(format!("Unknown static property '{class_name}.{prop}'"))
2764 }
2765}
2766
#[cfg(test)]
mod class_registry_tests {
    //! Tests for the process-wide class metadata registries.
    //!
    //! The registries are global and tests run in parallel, so every test
    //! registers its classes under names produced by `unique_class_name`.

    use super::{
        class_has_enumeration_member, class_names, get_class, get_static_property_value,
        is_class_abstract, is_class_or_subclass, is_class_sealed, lookup_method,
        lookup_property, register_class, register_class_enumerations,
        register_class_with_modifiers, register_class_with_sealed,
        set_static_property_value_in_owner, Access, ClassDef, MethodDef, PropertyDef, Value,
    };
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU64, Ordering};

    static TEST_CLASS_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Returns a class name no other test in this process will produce.
    fn unique_class_name(prefix: &str) -> String {
        let id = TEST_CLASS_COUNTER.fetch_add(1, Ordering::Relaxed);
        format!("{}_{}", prefix, id)
    }

    /// Class definition with no properties and no methods.
    fn bare_class(name: &str, parent: Option<&str>) -> ClassDef {
        ClassDef {
            name: name.to_string(),
            parent: parent.map(String::from),
            properties: HashMap::new(),
            methods: HashMap::new(),
        }
    }

    /// Public, non-constant, non-dependent property without a default value.
    fn public_property(name: &str, is_static: bool) -> PropertyDef {
        PropertyDef {
            name: name.to_string(),
            is_static,
            is_constant: false,
            is_dependent: false,
            get_access: Access::Public,
            set_access: Access::Public,
            default_value: None,
        }
    }

    #[test]
    fn primitive_classes_expose_static_zeros_method_metadata() {
        for class_name in ["double", "single", "logical"] {
            let class_def = get_class(class_name).expect("primitive class should be registered");
            let method = class_def
                .methods
                .get("zeros")
                .expect("primitive class should expose zeros static method");
            assert!(method.is_static, "zeros should be static on {class_name}");
            assert_eq!(method.function_name, "zeros");
            assert_eq!(method.implicit_class_argument.as_deref(), Some(class_name));

            let (resolved, owner) =
                lookup_method(class_name, "zeros").expect("lookup should find primitive zeros");
            assert_eq!(owner, class_name);
            assert_eq!(resolved.function_name, "zeros");
            assert_eq!(
                resolved.implicit_class_argument.as_deref(),
                Some(class_name)
            );
        }
    }

    #[test]
    fn method_lookup_uses_parent_class_metadata_chain() {
        let parent_name = unique_class_name("plan6_parent");
        let child_name = unique_class_name("plan6_child");

        let mut parent = bare_class(&parent_name, None);
        parent.methods.insert(
            "parentOnly".to_string(),
            MethodDef {
                name: "parentOnly".to_string(),
                is_static: false,
                is_abstract: false,
                is_sealed: false,
                access: Access::Public,
                function_name: "parentOnly_impl".to_string(),
                implicit_class_argument: None,
            },
        );
        register_class(parent);
        register_class(bare_class(&child_name, Some(parent_name.as_str())));

        let (method, owner) = lookup_method(&child_name, "parentOnly")
            .expect("child lookup should resolve inherited method through parent metadata");
        assert_eq!(owner, parent_name);
        assert_eq!(method.function_name, "parentOnly_impl");
    }

    #[test]
    fn method_lookup_handles_parent_cycle() {
        let class_a = unique_class_name("plan6_cycle_method_a");
        let class_b = unique_class_name("plan6_cycle_method_b");

        register_class(bare_class(&class_a, Some(class_b.as_str())));
        register_class(bare_class(&class_b, Some(class_a.as_str())));

        assert!(
            lookup_method(&class_a, "missing").is_none(),
            "cyclic parent metadata should terminate missing method lookup"
        );
    }

    #[test]
    fn property_lookup_uses_parent_class_metadata_chain() {
        let parent_name = unique_class_name("plan6_property_parent");
        let child_name = unique_class_name("plan6_property_child");

        let mut parent = bare_class(&parent_name, None);
        parent
            .properties
            .insert("parentFlag".to_string(), public_property("parentFlag", false));
        register_class(parent);
        register_class(bare_class(&child_name, Some(parent_name.as_str())));

        let (property, owner) = lookup_property(&child_name, "parentFlag")
            .expect("child property lookup should resolve inherited property through parent");
        assert_eq!(owner, parent_name);
        assert_eq!(property.name, "parentFlag");
        assert!(!property.is_static);
    }

    #[test]
    fn property_lookup_handles_parent_cycle() {
        let class_a = unique_class_name("plan6_cycle_property_a");
        let class_b = unique_class_name("plan6_cycle_property_b");

        register_class(bare_class(&class_a, Some(class_b.as_str())));
        register_class(bare_class(&class_b, Some(class_a.as_str())));

        assert!(
            lookup_property(&class_a, "missing").is_none(),
            "cyclic parent metadata should terminate missing property lookup"
        );
    }

    #[test]
    fn subclass_check_follows_parent_chain() {
        let grandparent = unique_class_name("plan6_subclass_grandparent");
        let parent = unique_class_name("plan6_subclass_parent");
        let child = unique_class_name("plan6_subclass_child");

        register_class(bare_class(&grandparent, None));
        register_class(bare_class(&parent, Some(grandparent.as_str())));
        register_class(bare_class(&child, Some(parent.as_str())));

        assert!(is_class_or_subclass(&child, &child));
        assert!(is_class_or_subclass(&child, &parent));
        assert!(is_class_or_subclass(&child, &grandparent));
        assert!(!is_class_or_subclass(&parent, &child));
        assert!(!is_class_or_subclass(&grandparent, &parent));
    }

    #[test]
    fn subclass_check_handles_parent_cycle_and_unknown_classes() {
        let class_a = unique_class_name("plan6_cycle_subclass_a");
        let class_b = unique_class_name("plan6_cycle_subclass_b");
        let unrelated = unique_class_name("plan6_cycle_subclass_unrelated");

        register_class(bare_class(&class_a, Some(class_b.as_str())));
        register_class(bare_class(&class_b, Some(class_a.as_str())));

        assert!(is_class_or_subclass(&class_a, &class_b));
        assert!(
            !is_class_or_subclass(&class_a, &unrelated),
            "cyclic parent metadata should terminate a failing subclass check"
        );
        // Unregistered names are only related to themselves.
        assert!(is_class_or_subclass(&unrelated, &unrelated));
        assert!(!is_class_or_subclass(&unrelated, &class_a));
    }

    #[test]
    fn modifiers_are_recorded_and_cleared_on_reregistration() {
        let name = unique_class_name("plan6_modifiers");

        register_class_with_modifiers(bare_class(&name, None), true, true);
        assert!(is_class_sealed(&name));
        assert!(is_class_abstract(&name));
        assert!(class_names().contains(&name));

        register_class_with_sealed(bare_class(&name, None), true);
        assert!(is_class_sealed(&name));
        assert!(!is_class_abstract(&name));

        register_class(bare_class(&name, None));
        assert!(!is_class_sealed(&name));
        assert!(!is_class_abstract(&name));
    }

    #[test]
    fn enumeration_members_are_replaced_not_merged() {
        let name = unique_class_name("plan6_enum");
        register_class(bare_class(&name, None));
        assert!(!class_has_enumeration_member(&name, "Red"));

        register_class_enumerations(&name, vec!["Red".to_string(), "Green".to_string()]);
        assert!(class_has_enumeration_member(&name, "Red"));
        assert!(class_has_enumeration_member(&name, "Green"));

        register_class_enumerations(&name, vec!["Blue".to_string()]);
        assert!(!class_has_enumeration_member(&name, "Red"));
        assert!(class_has_enumeration_member(&name, "Blue"));

        let unknown = unique_class_name("plan6_enum_unknown");
        assert!(!class_has_enumeration_member(&unknown, "Blue"));
    }

    #[test]
    fn static_property_writes_land_on_declaring_class() {
        let parent_name = unique_class_name("plan6_static_parent");
        let child_name = unique_class_name("plan6_static_child");

        let mut parent = bare_class(&parent_name, None);
        parent
            .properties
            .insert("counter".to_string(), public_property("counter", true));
        register_class(parent);
        register_class(bare_class(&child_name, Some(parent_name.as_str())));

        set_static_property_value_in_owner(&child_name, "counter", Value::Num(3.0))
            .expect("inherited static property should resolve to its declaring class");
        assert_eq!(
            get_static_property_value(&parent_name, "counter"),
            Some(Value::Num(3.0))
        );
        assert_eq!(get_static_property_value(&child_name, "counter"), None);

        let err = set_static_property_value_in_owner(&child_name, "missing", Value::Num(1.0))
            .expect_err("unknown property should be rejected");
        assert_eq!(err, format!("Unknown static property '{child_name}.missing'"));
    }
}
2928}