use runmat_builtins::{CharArray, Tensor, Value};
use crate::{build_runtime_error, BuiltinResult, RuntimeError};
/// Builds the `RuntimeError` used for every failure reported by this module.
///
/// `message` is passed straight to `build_runtime_error`; the resulting builder
/// is finalized with `.build()` without attaching any further context.
fn concat_error(message: impl Into<String>) -> RuntimeError {
build_runtime_error(message).build()
}
/// Converts a numeric code point into a 1x1 char array.
///
/// `error_prefix` is prepended (followed by `": "`) to every error message so
/// callers can identify which operation rejected the value.
///
/// # Errors
/// Fails when `value` is not a finite integer, when it lies outside
/// `0..=u32::MAX`, or when `char::from_u32` rejects it (e.g. a surrogate or a
/// value above the Unicode range).
pub fn char_array_from_f64_with_prefix(
    value: f64,
    error_prefix: &str,
) -> BuiltinResult<CharArray> {
    let fail = |reason: &str| concat_error(format!("{error_prefix}: {reason}"));
    if !(value.is_finite() && value.fract() == 0.0) {
        return Err(fail("expected integer code point"));
    }
    if !(0.0..=f64::from(u32::MAX)).contains(&value) {
        return Err(fail("code point out of range"));
    }
    // The range check above guarantees the cast cannot truncate.
    match char::from_u32(value as u32) {
        Some(ch) => CharArray::new(vec![ch], 1, 1).map_err(concat_error),
        None => Err(fail("invalid code point")),
    }
}
/// Converts `value` to a 1x1 char array, tagging any error with the prefix
/// used by the concatenation paths of this module.
fn char_array_from_f64(value: f64) -> BuiltinResult<CharArray> {
    const CONCAT_PREFIX: &str = "char concat";
    char_array_from_f64_with_prefix(value, CONCAT_PREFIX)
}
/// Horizontally concatenates two 2-D tensors (`[a, b]`).
///
/// Tensors are stored column-major, so placing `b`'s columns after `a`'s is
/// simply `a`'s first `rows * cols` elements followed by `b`'s. A 0x0 operand
/// contributes nothing and the other operand is returned unchanged.
///
/// # Errors
/// Returns an error when the operands have different row counts.
pub fn hcat_matrices(a: &Tensor, b: &Tensor) -> BuiltinResult<Tensor> {
    let is_empty = |t: &Tensor| t.rows() == 0 && t.cols() == 0;
    if is_empty(a) {
        return Ok(b.clone());
    }
    if is_empty(b) {
        return Ok(a.clone());
    }
    let rows = a.rows();
    if rows != b.rows() {
        return Err(concat_error(format!(
            "Cannot horizontally concatenate matrices with different row counts: {} vs {}",
            a.rows, b.rows
        )));
    }
    let a_len = rows * a.cols();
    let b_len = rows * b.cols();
    let mut combined = Vec::with_capacity(a_len + b_len);
    combined.extend_from_slice(&a.data[..a_len]);
    combined.extend_from_slice(&b.data[..b_len]);
    Tensor::new_2d(combined, rows, a.cols() + b.cols()).map_err(concat_error)
}
/// Vertically concatenates two 2-D tensors (`[a; b]`).
///
/// Tensors are stored column-major, so every output column is the matching
/// column of `a` followed by the matching column of `b`. A 0x0 operand
/// contributes nothing and the other operand is returned unchanged.
///
/// # Errors
/// Returns an error when the operands have different column counts.
pub fn vcat_matrices(a: &Tensor, b: &Tensor) -> BuiltinResult<Tensor> {
    if a.rows() == 0 && a.cols() == 0 {
        return Ok(b.clone());
    }
    if b.rows() == 0 && b.cols() == 0 {
        return Ok(a.clone());
    }
    if a.cols() != b.cols() {
        return Err(concat_error(format!(
            "Cannot vertically concatenate matrices with different column counts: {} vs {}",
            a.cols, b.cols
        )));
    }
    let (a_rows, b_rows) = (a.rows(), b.rows());
    let new_rows = a_rows + b_rows;
    let new_cols = a.cols();
    let mut new_data = Vec::with_capacity(new_rows * new_cols);
    // Interleave column by column: appending all of `a` and then all of `b`
    // is only correct for single-column operands.
    for col in 0..new_cols {
        new_data.extend_from_slice(&a.data[col * a_rows..(col + 1) * a_rows]);
        new_data.extend_from_slice(&b.data[col * b_rows..(col + 1) * b_rows]);
    }
    Tensor::new_2d(new_data, new_rows, new_cols).map_err(concat_error)
}

/// Builds the 0x0 numeric tensor returned when there is nothing to concatenate.
fn empty_concat_tensor() -> BuiltinResult<Value> {
    let empty = Tensor::new(vec![], vec![0, 0]).map_err(concat_error)?;
    Ok(Value::Tensor(empty))
}

/// True for string scalars and string arrays, whose presence makes a
/// concatenation produce a string array.
fn is_string_operand(value: &Value) -> bool {
    matches!(value, Value::String(_) | Value::StringArray(_))
}

/// Records `extent` as the dimension every operand must share, or fails with
/// `mismatch` if an earlier operand already fixed a different value.
fn unify_concat_extent(
    shared: &mut Option<usize>,
    extent: usize,
    mismatch: &str,
) -> BuiltinResult<()> {
    match *shared {
        None => *shared = Some(extent),
        Some(existing) if existing != extent => return Err(concat_error(mismatch)),
        Some(_) => {}
    }
    Ok(())
}

/// Splits a row-major char array into one `String` per row.
fn char_rows_to_strings(ca: &CharArray) -> Vec<String> {
    (0..ca.rows)
        .map(|r| ca.data[r * ca.cols..(r + 1) * ca.cols].iter().collect())
        .collect()
}

/// Builds a string array, prefixing construction failures with `op`.
fn build_string_block(
    data: Vec<String>,
    shape: Vec<usize>,
    op: &str,
) -> BuiltinResult<runmat_builtins::StringArray> {
    runmat_builtins::StringArray::new(data, shape)
        .map_err(|e| concat_error(format!("{op}: {e}")))
}

/// Converts one operand of a string concatenation into a string-array block.
///
/// String arrays are used as-is; string, numeric, complex and integer scalars
/// become 1x1 blocks; an R x C char array becomes an R x 1 block holding one
/// string per row. Callers skip whichever empty char arrays they ignore
/// before calling this.
fn string_concat_block(
    value: &Value,
    op: &str,
) -> BuiltinResult<runmat_builtins::StringArray> {
    match value {
        Value::StringArray(sa) => Ok(sa.clone()),
        Value::String(s) => build_string_block(vec![s.clone()], vec![1, 1], op),
        Value::CharArray(ca) => {
            build_string_block(char_rows_to_strings(ca), vec![ca.rows, 1], op)
        }
        Value::Num(n) => build_string_block(vec![n.to_string()], vec![1, 1], op),
        Value::Complex(..) => build_string_block(vec![value.to_string()], vec![1, 1], op),
        Value::Int(i) => build_string_block(vec![i.to_i64().to_string()], vec![1, 1], op),
        _ => Err(concat_error(format!(
            "Cannot concatenate value of type {value:?} with string array"
        ))),
    }
}

/// Converts one operand of a char concatenation into a char-array block.
///
/// Char arrays are used as-is; numeric, integer and logical scalars become a
/// 1x1 char array holding the character with that code point.
fn char_concat_block(value: &Value) -> BuiltinResult<CharArray> {
    match value {
        Value::CharArray(ca) => Ok(ca.clone()),
        Value::Num(n) => char_array_from_f64(*n),
        Value::Int(i) => char_array_from_f64(i.to_f64()),
        Value::Bool(flag) => char_array_from_f64(if *flag { 1.0 } else { 0.0 }),
        _ => Err(concat_error(format!(
            "Cannot concatenate value of type {value:?} with char array"
        ))),
    }
}

/// Converts a numeric scalar operand into a 1x1 tensor, or returns `None` for
/// values that are not numeric scalars.
fn numeric_scalar_tensor(value: &Value) -> Option<BuiltinResult<Tensor>> {
    let real = match value {
        Value::Num(n) => *n,
        // NOTE(review): only the real part of a complex scalar is kept; the
        // imaginary part is silently dropped (pre-existing behavior).
        Value::Complex(re, _im) => *re,
        Value::Int(i) => i.to_f64(),
        _ => return None,
    };
    Some(Tensor::new_2d(vec![real], 1, 1).map_err(concat_error))
}

/// Direction of a numeric concatenation.
#[derive(Clone, Copy)]
enum ConcatAxis {
    /// `[a, b]`: operands must agree on their row count.
    Horizontal,
    /// `[a; b]`: operands must agree on their column count.
    Vertical,
}

impl ConcatAxis {
    /// The dimension every operand must share along this axis.
    fn shared_extent(self, tensor: &Tensor) -> usize {
        match self {
            ConcatAxis::Horizontal => tensor.rows(),
            ConcatAxis::Vertical => tensor.cols(),
        }
    }

    /// Name of the shared dimension, as used in error messages.
    fn dim_name(self) -> &'static str {
        match self {
            ConcatAxis::Horizontal => "row",
            ConcatAxis::Vertical => "column",
        }
    }

    /// Concatenates two tensors along this axis.
    fn combine(self, a: &Tensor, b: &Tensor) -> BuiltinResult<Tensor> {
        match self {
            ConcatAxis::Horizontal => hcat_matrices(a, b),
            ConcatAxis::Vertical => vcat_matrices(a, b),
        }
    }
}

/// Numeric concatenation shared by `hcat_values` and `vcat_values`.
///
/// 0x0 tensors are skipped and scalars become 1x1 tensors. The operands are
/// folded pairwise with `hcat_matrices`/`vcat_matrices`. When every operand
/// was skipped, the result is a 0x0 tensor.
fn concat_numeric_values(values: &[Value], axis: ConcatAxis) -> BuiltinResult<Value> {
    let dim = axis.dim_name();
    let mut matrices: Vec<Tensor> = Vec::with_capacity(values.len());
    // 0 means "not fixed yet". The pairwise combine below re-validates shapes,
    // so a leading zero-extent (e.g. 0xN) tensor is still rejected there.
    let mut shared = 0usize;
    for value in values {
        let (matrix, is_scalar) = match value {
            Value::Tensor(m) if m.rows() == 0 && m.cols() == 0 => continue,
            Value::Tensor(m) => (m.clone(), false),
            other => match numeric_scalar_tensor(other) {
                Some(scalar) => (scalar?, true),
                None => {
                    return Err(concat_error(format!(
                        "Cannot concatenate value of type {value:?}"
                    )))
                }
            },
        };
        let extent = axis.shared_extent(&matrix);
        if shared == 0 {
            shared = extent;
        } else if shared != extent {
            let message = if is_scalar {
                format!("Cannot concatenate scalar with multi-{dim} matrix")
            } else {
                format!(
                    "Cannot concatenate matrices with different {dim} counts: {shared} vs {extent}"
                )
            };
            return Err(concat_error(message));
        }
        matrices.push(matrix);
    }
    let mut remaining = matrices.into_iter();
    // Every operand may have been a skipped 0x0 tensor (e.g. `[[], []]`).
    let mut result = match remaining.next() {
        Some(first) => first,
        None => return empty_concat_tensor(),
    };
    for matrix in remaining {
        result = axis.combine(&result, &matrix)?;
    }
    Ok(Value::Tensor(result))
}

/// Horizontally concatenates a row of runtime values (`[v1, v2, ...]`).
///
/// The result kind is chosen from the operands. Any string or string array
/// yields a string array; otherwise any char array yields a char array;
/// otherwise the operands are combined as numeric tensors. An empty input, or
/// one made only of 0x0 tensors, yields a 0x0 tensor.
///
/// # Errors
/// Fails on unsupported operand types, on operands whose row counts disagree,
/// and on scalars that are not valid code points in a char concatenation.
pub fn hcat_values(values: &[Value]) -> BuiltinResult<Value> {
    if values.is_empty() {
        return empty_concat_tensor();
    }
    if values.iter().any(is_string_operand) {
        return hcat_string_values(values);
    }
    if values.iter().any(|v| matches!(v, Value::CharArray(_))) {
        return hcat_char_values(values);
    }
    concat_numeric_values(values, ConcatAxis::Horizontal)
}

/// String branch of `hcat_values`: every block must share a row count, and
/// blocks are laid side by side in column-major order.
fn hcat_string_values(values: &[Value]) -> BuiltinResult<Value> {
    let mut rows: Option<usize> = None;
    let mut cols_total = 0usize;
    let mut blocks: Vec<runmat_builtins::StringArray> = Vec::with_capacity(values.len());
    for v in values {
        let block = match v {
            // Zero-row char arrays contribute nothing.
            Value::CharArray(ca) if ca.rows == 0 => continue,
            _ => string_concat_block(v, "string hcat")?,
        };
        unify_concat_extent(&mut rows, block.rows(), "string hcat: row mismatch")?;
        cols_total += block.cols();
        blocks.push(block);
    }
    let rows = rows.unwrap_or(0);
    let mut data: Vec<String> = Vec::with_capacity(rows * cols_total);
    for block in &blocks {
        // Column-major: a block's first `rows * cols` elements are its columns in order.
        data.extend_from_slice(&block.data[..rows * block.cols()]);
    }
    let sa = runmat_builtins::StringArray::new(data, vec![rows, cols_total])
        .map_err(|e| concat_error(format!("string hcat: {e}")))?;
    Ok(Value::StringArray(sa))
}

/// Char branch of `hcat_values`: every block must share a row count, and each
/// output row (row-major) is the concatenation of that row from every block.
fn hcat_char_values(values: &[Value]) -> BuiltinResult<Value> {
    let mut rows: Option<usize> = None;
    let mut cols_total = 0usize;
    let mut blocks: Vec<CharArray> = Vec::with_capacity(values.len());
    for v in values {
        let block = match v {
            Value::CharArray(ca) if ca.rows == 0 && ca.cols == 0 => continue,
            _ => char_concat_block(v)?,
        };
        unify_concat_extent(&mut rows, block.rows, "char hcat: row mismatch")?;
        cols_total += block.cols;
        blocks.push(block);
    }
    let rows = rows.unwrap_or(0);
    let mut data: Vec<char> = Vec::with_capacity(rows * cols_total);
    for r in 0..rows {
        for block in &blocks {
            data.extend_from_slice(&block.data[r * block.cols..(r + 1) * block.cols]);
        }
    }
    let ca = CharArray::new(data, rows, cols_total)
        .map_err(|e| concat_error(format!("char hcat: {e}")))?;
    Ok(Value::CharArray(ca))
}

/// Vertically concatenates a column of runtime values (`[v1; v2; ...]`).
///
/// The result kind is chosen the same way as in `hcat_values` (string array,
/// then char array, then numeric tensor). An empty input, or one made only of
/// 0x0 tensors, yields a 0x0 tensor.
///
/// # Errors
/// Fails on unsupported operand types, on operands whose column counts
/// disagree, and on scalars that are not valid code points in a char
/// concatenation.
pub fn vcat_values(values: &[Value]) -> BuiltinResult<Value> {
    if values.is_empty() {
        return empty_concat_tensor();
    }
    if values.iter().any(is_string_operand) {
        return vcat_string_values(values);
    }
    if values.iter().any(|v| matches!(v, Value::CharArray(_))) {
        return vcat_char_values(values);
    }
    concat_numeric_values(values, ConcatAxis::Vertical)
}

/// String branch of `vcat_values`: every block must share a column count, and
/// each output column (column-major) stacks that column from every block.
///
/// A non-empty R x C char array contributes R rows (one string per row),
/// matching how `hcat_values` converts char matrices.
fn vcat_string_values(values: &[Value]) -> BuiltinResult<Value> {
    let mut cols: Option<usize> = None;
    let mut rows_total = 0usize;
    let mut blocks: Vec<runmat_builtins::StringArray> = Vec::with_capacity(values.len());
    for v in values {
        let block = match v {
            // Char arrays without rows or columns contribute nothing.
            Value::CharArray(ca) if ca.rows == 0 || ca.cols == 0 => continue,
            _ => string_concat_block(v, "string vcat")?,
        };
        unify_concat_extent(&mut cols, block.cols(), "string vcat: column mismatch")?;
        rows_total += block.rows();
        blocks.push(block);
    }
    let cols = cols.unwrap_or(0);
    let mut data: Vec<String> = Vec::with_capacity(rows_total * cols);
    for c in 0..cols {
        for block in &blocks {
            let height = block.rows();
            data.extend_from_slice(&block.data[c * height..(c + 1) * height]);
        }
    }
    let sa = runmat_builtins::StringArray::new(data, vec![rows_total, cols])
        .map_err(|e| concat_error(format!("string vcat: {e}")))?;
    Ok(Value::StringArray(sa))
}

/// Char branch of `vcat_values`: every block must share a column count, and
/// the row-major blocks are stacked one after another.
fn vcat_char_values(values: &[Value]) -> BuiltinResult<Value> {
    let mut cols: Option<usize> = None;
    let mut rows_total = 0usize;
    let mut blocks: Vec<CharArray> = Vec::with_capacity(values.len());
    for v in values {
        let block = match v {
            Value::CharArray(ca) if ca.rows == 0 && ca.cols == 0 => continue,
            _ => char_concat_block(v)?,
        };
        unify_concat_extent(&mut cols, block.cols, "char vcat: column mismatch")?;
        rows_total += block.rows;
        blocks.push(block);
    }
    let cols = cols.unwrap_or(0);
    let mut data: Vec<char> = Vec::with_capacity(rows_total * cols);
    for block in &blocks {
        data.extend_from_slice(&block.data[..block.rows * cols]);
    }
    let ca = CharArray::new(data, rows_total, cols)
        .map_err(|e| concat_error(format!("char vcat: {e}")))?;
    Ok(Value::CharArray(ca))
}

/// Builds a matrix literal (`[r1; r2; ...]`) from its rows of values.
///
/// Each non-empty row is joined with the `horzcat` builtin, and the row
/// results are stacked with `vertcat`. An empty row stands for a 0x0 tensor.
/// No rows yield a 0x0 tensor, and a single row is returned without calling
/// `vertcat`.
pub async fn create_matrix_from_values(rows: &[Vec<Value>]) -> BuiltinResult<Value> {
    let mut row_values: Vec<Value> = Vec::with_capacity(rows.len());
    for row in rows {
        let row_value = if row.is_empty() {
            empty_concat_tensor()?
        } else {
            crate::call_builtin_async("horzcat", row).await?
        };
        row_values.push(row_value);
    }
    match row_values.len() {
        0 => empty_concat_tensor(),
        1 => Ok(row_values.swap_remove(0)),
        _ => Ok(crate::call_builtin_async("vertcat", &row_values).await?),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn empty_tensor() -> Tensor {
        Tensor::new(vec![], vec![0, 0]).unwrap()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_hcat_matrices() {
        let a = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let b = Tensor::new_2d(vec![5.0, 6.0], 2, 1).unwrap();
        let result = hcat_matrices(&a, &b).unwrap();
        assert_eq!(result.rows(), 2);
        assert_eq!(result.cols(), 3);
        assert_eq!(result.data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_vcat_matrices() {
        let a = Tensor::new_2d(vec![1.0, 2.0], 1, 2).unwrap();
        let b = Tensor::new_2d(vec![3.0, 4.0], 1, 2).unwrap();
        let result = vcat_matrices(&a, &b).unwrap();
        assert_eq!(result.rows(), 2);
        assert_eq!(result.cols(), 2);
        // [1 2; 3 4] stored column-major.
        assert_eq!(result.data, vec![1.0, 3.0, 2.0, 4.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_vcat_matrices_uneven_rows() {
        // a = [1 3; 2 4], b = [5 6]  ->  [1 3; 2 4; 5 6]
        let a = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let b = Tensor::new_2d(vec![5.0, 6.0], 1, 2).unwrap();
        let result = vcat_matrices(&a, &b).unwrap();
        assert_eq!(result.rows(), 3);
        assert_eq!(result.cols(), 2);
        assert_eq!(result.data, vec![1.0, 2.0, 5.0, 3.0, 4.0, 6.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_hcat_values_scalars() {
        let values = vec![Value::Num(1.0), Value::Num(2.0), Value::Num(3.0)];
        let result = hcat_values(&values).unwrap();
        if let Value::Tensor(m) = result {
            assert_eq!(m.rows(), 1);
            assert_eq!(m.cols(), 3);
            assert_eq!(m.data, vec![1.0, 2.0, 3.0]);
        } else {
            panic!("Expected matrix result");
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_vcat_values_scalars() {
        let values = vec![Value::Num(1.0), Value::Num(2.0)];
        let result = vcat_values(&values).unwrap();
        if let Value::Tensor(m) = result {
            assert_eq!(m.rows(), 2);
            assert_eq!(m.cols(), 1);
            assert_eq!(m.data, vec![1.0, 2.0]);
        } else {
            panic!("Expected matrix result");
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_concat_values_all_empty_operands() {
        let values = vec![Value::Tensor(empty_tensor()), Value::Tensor(empty_tensor())];
        match hcat_values(&values).unwrap() {
            Value::Tensor(m) => {
                assert_eq!(m.rows(), 0);
                assert_eq!(m.cols(), 0);
            }
            _ => panic!("Expected empty matrix result"),
        }
        match vcat_values(&values).unwrap() {
            Value::Tensor(m) => {
                assert_eq!(m.rows(), 0);
                assert_eq!(m.cols(), 0);
            }
            _ => panic!("Expected empty matrix result"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_vcat_values_string_arrays_column_major() {
        let top = runmat_builtins::StringArray::new(
            vec!["a".to_string(), "b".to_string()],
            vec![1, 2],
        )
        .unwrap();
        let bottom = runmat_builtins::StringArray::new(
            vec!["c".to_string(), "d".to_string()],
            vec![1, 2],
        )
        .unwrap();
        let values = vec![Value::StringArray(top), Value::StringArray(bottom)];
        match vcat_values(&values).unwrap() {
            Value::StringArray(sa) => {
                assert_eq!(sa.rows(), 2);
                assert_eq!(sa.cols(), 2);
                // ["a" "b"; "c" "d"] stored column-major.
                assert_eq!(sa.data, vec!["a", "c", "b", "d"]);
            }
            _ => panic!("Expected string array result"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_vcat_values_char_matrix_rows_become_strings() {
        let chars = CharArray::new(vec!['a', 'b', 'c', 'd'], 2, 2).unwrap();
        let values = vec![Value::String("x".to_string()), Value::CharArray(chars)];
        match vcat_values(&values).unwrap() {
            Value::StringArray(sa) => {
                assert_eq!(sa.rows(), 3);
                assert_eq!(sa.cols(), 1);
                assert_eq!(sa.data, vec!["x", "ab", "cd"]);
            }
            _ => panic!("Expected string array result"),
        }
    }
}