// runmat_runtime/builtins/common/tensor.rs
use std::convert::TryFrom;

use runmat_builtins::{LogicalArray, NumericDType, Tensor, Value};

/// Returns the number of elements in a tensor with the given `shape`: the
/// product of all dimensions (an empty shape yields 1).
///
/// Any zero-length dimension makes the count 0, even when the product of the
/// remaining dimensions would not fit in `usize`.
///
/// # Panics
///
/// Panics if the element count does not fit in `usize`.
pub fn element_count(shape: &[usize]) -> usize {
    // A zero dimension empties the tensor no matter how large the other
    // dimensions are; checking it first keeps huge-but-empty shapes from
    // tripping the overflow check below.
    if shape.contains(&0) {
        return 0;
    }
    // With every dimension >= 1 the running product never decreases, so an
    // intermediate overflow means the full product overflows as well.
    shape
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
        .expect("tensor::element_count: overflow computing element count")
}

17pub fn zeros(shape: &[usize]) -> Result<Tensor, String> {
19 Tensor::new(vec![0.0; element_count(shape)], shape.to_vec())
20 .map_err(|e| format!("tensor zeros: {e}"))
21}
22
23pub fn ones(shape: &[usize]) -> Result<Tensor, String> {
25 Tensor::new(vec![1.0; element_count(shape)], shape.to_vec())
26 .map_err(|e| format!("tensor ones: {e}"))
27}
28
29pub fn zeros_with_dtype(shape: &[usize], dtype: NumericDType) -> Result<Tensor, String> {
31 Tensor::new_with_dtype(vec![0.0; element_count(shape)], shape.to_vec(), dtype)
32 .map_err(|e| format!("tensor zeros: {e}"))
33}
34
35pub fn ones_with_dtype(shape: &[usize], dtype: NumericDType) -> Result<Tensor, String> {
37 Tensor::new_with_dtype(vec![1.0; element_count(shape)], shape.to_vec(), dtype)
38 .map_err(|e| format!("tensor ones: {e}"))
39}
40
41pub fn logical_to_tensor(logical: &LogicalArray) -> Result<Tensor, String> {
43 let data: Vec<f64> = logical
44 .data
45 .iter()
46 .map(|&b| if b != 0 { 1.0 } else { 0.0 })
47 .collect();
48 Tensor::new(data, logical.shape.clone()).map_err(|e| format!("logical->tensor: {e}"))
49}
50
51fn value_into_tensor_impl(name: &str, value: Value) -> Result<Tensor, String> {
52 match value {
53 Value::Tensor(t) => Ok(t),
54 Value::LogicalArray(logical) => logical_to_tensor(&logical),
55 Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).map_err(|e| format!("tensor: {e}")),
56 Value::Int(i) => {
57 Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|e| format!("tensor: {e}"))
58 }
59 Value::Bool(b) => Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
60 .map_err(|e| format!("tensor: {e}")),
61 other => Err(format!(
62 "{name}: unsupported input type {:?}; expected numeric or logical values",
63 other
64 )),
65 }
66}
67
68pub fn value_into_tensor(value: Value) -> Result<Tensor, String> {
70 value_into_tensor_impl("sum", value)
71}
72
73pub fn value_into_tensor_for(name: &str, value: Value) -> Result<Tensor, String> {
75 value_into_tensor_impl(name, value)
76}
77
78pub fn value_to_tensor(value: &Value) -> Result<Tensor, String> {
80 value_into_tensor(value.clone())
81}
82
83pub fn tensor_into_value(tensor: Tensor) -> Value {
88 if tensor.data.len() == 1 {
89 Value::Num(tensor.data[0])
90 } else {
91 Value::Tensor(tensor)
92 }
93}
94
95pub fn is_scalar_tensor(tensor: &Tensor) -> bool {
97 tensor.data.len() == 1
98}
99
100pub fn parse_dimension(value: &Value, name: &str) -> Result<usize, String> {
102 match value {
103 Value::Int(i) => {
104 let raw = i.to_i64();
105 if raw < 1 {
106 return Err(format!("{name}: dimension must be >= 1"));
107 }
108 Ok(raw as usize)
109 }
110 Value::Num(n) => {
111 if !n.is_finite() {
112 return Err(format!("{name}: dimension must be finite"));
113 }
114 let rounded = n.round();
115 if (rounded - n).abs() > 1e-6 {
117 return Err(format!("{name}: dimension must be an integer"));
118 }
119 if rounded < 1.0 {
120 return Err(format!("{name}: dimension must be >= 1"));
121 }
122 Ok(rounded as usize)
123 }
124 other => Err(format!(
125 "{name}: dimension must be numeric, got {:?}",
126 other
127 )),
128 }
129}
130
131pub fn value_to_string(value: &Value) -> Option<String> {
133 String::try_from(value).ok()
134}