1use runmat_builtins::{Tensor, Value};
7
8pub fn hcat_matrices(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
11 if a.rows() == 0 && a.cols() == 0 {
13 return Ok(b.clone());
14 }
15 if b.rows() == 0 && b.cols() == 0 {
16 return Ok(a.clone());
17 }
18 if a.rows() != b.rows() {
19 return Err(format!(
20 "Cannot horizontally concatenate matrices with different row counts: {} vs {}",
21 a.rows, b.rows
22 ));
23 }
24
25 let new_rows = a.rows();
26 let new_cols = a.cols() + b.cols();
27 let mut new_data = Vec::with_capacity(new_rows * new_cols);
28
29 for col in 0..new_cols {
31 if col < a.cols() {
32 for row in 0..a.rows() {
33 new_data.push(a.data[row + col * a.rows()]);
34 }
35 } else {
36 let bcol = col - a.cols();
37 for row in 0..b.rows() {
38 new_data.push(b.data[row + bcol * b.rows()]);
39 }
40 }
41 }
42
43 Tensor::new_2d(new_data, new_rows, new_cols)
44}
45
46pub fn vcat_matrices(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
49 if a.rows() == 0 && a.cols() == 0 {
51 return Ok(b.clone());
52 }
53 if b.rows() == 0 && b.cols() == 0 {
54 return Ok(a.clone());
55 }
56 if a.cols() != b.cols() {
57 return Err(format!(
58 "Cannot vertically concatenate matrices with different column counts: {} vs {}",
59 a.cols, b.cols
60 ));
61 }
62
63 let new_rows = a.rows() + b.rows();
64 let new_cols = a.cols();
65 let mut new_data = Vec::with_capacity(new_rows * new_cols);
66
67 for col in 0..a.cols() {
69 for row in 0..a.rows() {
70 new_data.push(a.data[row + col * a.rows()]);
71 }
72 }
73 for col in 0..b.cols() {
74 for row in 0..b.rows() {
75 new_data.push(b.data[row + col * b.rows()]);
76 }
77 }
78
79 Tensor::new_2d(new_data, new_rows, new_cols)
80}
81
82pub fn hcat_values(values: &[Value]) -> Result<Value, String> {
84 if values.is_empty() {
85 return Ok(Value::Tensor(Tensor::new(vec![], vec![0, 0])?));
86 }
87
88 let has_str = values.iter().any(|v| {
90 matches!(
91 v,
92 Value::String(_) | Value::StringArray(_) | Value::CharArray(_)
93 )
94 });
95 if has_str {
96 let mut rows: Option<usize> = None;
99 let mut cols_total = 0usize;
100 let mut blocks: Vec<runmat_builtins::StringArray> = Vec::new();
101 for v in values {
102 match v {
103 Value::StringArray(sa) => {
104 if rows.is_none() {
105 rows = Some(sa.rows());
106 } else if rows != Some(sa.rows()) {
107 return Err("string hcat: row mismatch".to_string());
108 }
109 cols_total += sa.cols();
110 blocks.push(sa.clone());
111 }
112 Value::String(s) => {
113 let sa =
114 runmat_builtins::StringArray::new(vec![s.clone()], vec![1, 1]).unwrap();
115 if rows.is_none() {
116 rows = Some(1);
117 } else if rows != Some(1) {
118 return Err("string hcat: row mismatch".to_string());
119 }
120 cols_total += 1;
121 blocks.push(sa);
122 }
123 Value::CharArray(ca) => {
124 if ca.rows == 0 {
126 continue;
127 }
128 if rows.is_none() {
129 rows = Some(ca.rows);
130 } else if rows != Some(ca.rows) {
131 return Err("string hcat: row mismatch".to_string());
132 }
133 let mut out: Vec<String> = Vec::with_capacity(ca.rows);
134 for r in 0..ca.rows {
135 let mut s = String::with_capacity(ca.cols);
136 for c in 0..ca.cols {
137 s.push(ca.data[r * ca.cols + c]);
138 }
139 out.push(s);
140 }
141 let sa = runmat_builtins::StringArray::new(out, vec![ca.rows, 1]).unwrap();
142 cols_total += 1;
143 blocks.push(sa);
144 }
145 Value::Num(n) => {
146 let sa =
147 runmat_builtins::StringArray::new(vec![n.to_string()], vec![1, 1]).unwrap();
148 if rows.is_none() {
149 rows = Some(1);
150 } else if rows != Some(1) {
151 return Err("string hcat: row mismatch".to_string());
152 }
153 cols_total += 1;
154 blocks.push(sa);
155 }
156 Value::Complex(re, im) => {
157 let sa = runmat_builtins::StringArray::new(
158 vec![runmat_builtins::Value::Complex(*re, *im).to_string()],
159 vec![1, 1],
160 )
161 .unwrap();
162 if rows.is_none() {
163 rows = Some(1);
164 } else if rows != Some(1) {
165 return Err("string hcat: row mismatch".to_string());
166 }
167 cols_total += 1;
168 blocks.push(sa);
169 }
170 Value::Int(i) => {
171 let sa =
172 runmat_builtins::StringArray::new(vec![i.to_i64().to_string()], vec![1, 1])
173 .unwrap();
174 if rows.is_none() {
175 rows = Some(1);
176 } else if rows != Some(1) {
177 return Err("string hcat: row mismatch".to_string());
178 }
179 cols_total += 1;
180 blocks.push(sa);
181 }
182 Value::Tensor(_) | Value::Cell(_) => {
183 return Err(format!(
184 "Cannot concatenate value of type {v:?} with string array"
185 ))
186 }
187 _ => {
188 return Err(format!(
189 "Cannot concatenate value of type {v:?} with string array"
190 ))
191 }
192 }
193 }
194 let rows = rows.unwrap_or(0);
195 let mut data: Vec<String> = Vec::with_capacity(rows * cols_total);
196 for cacc in 0..cols_total {
197 let _ = cacc;
198 }
199 for block in &blocks {
201 for c in 0..block.cols() {
202 for r in 0..rows {
203 let idx = r + c * rows;
204 data.push(block.data[idx].clone());
205 }
206 }
207 }
208 let sa = runmat_builtins::StringArray::new(data, vec![rows, cols_total])
209 .map_err(|e| format!("string hcat: {e}"))?;
210 return Ok(Value::StringArray(sa));
211 }
212
213 let mut matrices = Vec::new();
215 let mut _total_cols = 0;
216 let mut rows = 0;
217
218 for value in values {
219 match value {
220 Value::Num(n) => {
221 let matrix = Tensor::new_2d(vec![*n], 1, 1)?;
222 if rows == 0 {
223 rows = 1;
224 } else if rows != 1 {
225 return Err("Cannot concatenate scalar with multi-row matrix".to_string());
226 }
227 _total_cols += 1;
228 matrices.push(matrix);
229 }
230 Value::Complex(re, _im) => {
231 let matrix = Tensor::new_2d(vec![*re], 1, 1)?; if rows == 0 {
233 rows = 1;
234 } else if rows != 1 {
235 return Err("Cannot concatenate scalar with multi-row matrix".to_string());
236 }
237 _total_cols += 1;
238 matrices.push(matrix);
239 }
240 Value::Int(i) => {
241 let matrix = Tensor::new_2d(vec![i.to_f64()], 1, 1)?;
242 if rows == 0 {
243 rows = 1;
244 } else if rows != 1 {
245 return Err("Cannot concatenate scalar with multi-row matrix".to_string());
246 }
247 _total_cols += 1;
248 matrices.push(matrix);
249 }
250 Value::Tensor(m) => {
251 if m.rows() == 0 && m.cols() == 0 {
253 continue;
254 }
255 if rows == 0 {
256 rows = m.rows();
257 } else if rows != m.rows() {
258 return Err(format!(
259 "Cannot concatenate matrices with different row counts: {} vs {}",
260 rows,
261 m.rows()
262 ));
263 }
264 _total_cols += m.cols();
265 matrices.push(m.clone());
266 }
267 _ => return Err(format!("Cannot concatenate value of type {value:?}")),
268 }
269 }
270
271 let mut result = matrices[0].clone();
273 for matrix in &matrices[1..] {
274 result = hcat_matrices(&result, matrix)?;
275 }
276
277 Ok(Value::Tensor(result))
278}
279
280pub fn vcat_values(values: &[Value]) -> Result<Value, String> {
282 if values.is_empty() {
283 return Ok(Value::Tensor(Tensor::new(vec![], vec![0, 0])?));
284 }
285
286 let has_str = values.iter().any(|v| {
288 matches!(
289 v,
290 Value::String(_) | Value::StringArray(_) | Value::CharArray(_)
291 )
292 });
293 if has_str {
294 let mut cols: Option<usize> = None;
296 let mut rows_total = 0usize;
297 let mut blocks: Vec<runmat_builtins::StringArray> = Vec::new();
298 for v in values {
299 match v {
300 Value::StringArray(sa) => {
301 if cols.is_none() {
302 cols = Some(sa.cols());
303 } else if cols != Some(sa.cols()) {
304 return Err("string vcat: column mismatch".to_string());
305 }
306 rows_total += sa.rows();
307 blocks.push(sa.clone());
308 }
309 Value::String(s) => {
310 let sa =
311 runmat_builtins::StringArray::new(vec![s.clone()], vec![1, 1]).unwrap();
312 rows_total += 1;
313 if cols.is_none() {
314 cols = Some(1);
315 } else if cols != Some(1) {
316 return Err("string vcat: column mismatch".to_string());
317 }
318 blocks.push(sa);
319 }
320 Value::CharArray(ca) => {
321 if ca.cols == 0 {
322 continue;
323 }
324 let out: String = ca.data.iter().collect();
325 let sa = runmat_builtins::StringArray::new(vec![out], vec![1, 1]).unwrap();
326 rows_total += 1;
327 if cols.is_none() {
328 cols = Some(1);
329 } else if cols != Some(1) {
330 return Err("string vcat: column mismatch".to_string());
331 }
332 blocks.push(sa);
333 }
334 Value::Num(n) => {
335 let sa =
336 runmat_builtins::StringArray::new(vec![n.to_string()], vec![1, 1]).unwrap();
337 rows_total += 1;
338 if cols.is_none() {
339 cols = Some(1);
340 } else if cols != Some(1) {
341 return Err("string vcat: column mismatch".to_string());
342 }
343 blocks.push(sa);
344 }
345 Value::Complex(re, im) => {
346 let sa = runmat_builtins::StringArray::new(
347 vec![runmat_builtins::Value::Complex(*re, *im).to_string()],
348 vec![1, 1],
349 )
350 .unwrap();
351 rows_total += 1;
352 if cols.is_none() {
353 cols = Some(1);
354 } else if cols != Some(1) {
355 return Err("string vcat: column mismatch".to_string());
356 }
357 blocks.push(sa);
358 }
359 Value::Int(i) => {
360 let sa =
361 runmat_builtins::StringArray::new(vec![i.to_i64().to_string()], vec![1, 1])
362 .unwrap();
363 rows_total += 1;
364 if cols.is_none() {
365 cols = Some(1);
366 } else if cols != Some(1) {
367 return Err("string vcat: column mismatch".to_string());
368 }
369 blocks.push(sa);
370 }
371 _ => {
372 return Err(format!(
373 "Cannot concatenate value of type {v:?} with string array"
374 ))
375 }
376 }
377 }
378 let cols = cols.unwrap_or(0);
379 let mut data: Vec<String> = Vec::with_capacity(rows_total * cols);
380 for block in &blocks {
382 for c in 0..cols {
383 for r in 0..block.rows() {
384 let idx = r + c * block.rows();
385 data.push(block.data[idx].clone());
386 }
387 }
388 }
389 let sa = runmat_builtins::StringArray::new(data, vec![rows_total, cols])
390 .map_err(|e| format!("string vcat: {e}"))?;
391 return Ok(Value::StringArray(sa));
392 }
393
394 let mut matrices = Vec::new();
396 let mut _total_rows = 0;
397 let mut cols = 0;
398
399 for value in values {
400 match value {
401 Value::Num(n) => {
402 let matrix = Tensor::new_2d(vec![*n], 1, 1)?;
403 if cols == 0 {
404 cols = 1;
405 } else if cols != 1 {
406 return Err("Cannot concatenate scalar with multi-column matrix".to_string());
407 }
408 _total_rows += 1;
409 matrices.push(matrix);
410 }
411 Value::Complex(re, _im) => {
412 let matrix = Tensor::new_2d(vec![*re], 1, 1)?;
413 if cols == 0 {
414 cols = 1;
415 } else if cols != 1 {
416 return Err("Cannot concatenate scalar with multi-column matrix".to_string());
417 }
418 _total_rows += 1;
419 matrices.push(matrix);
420 }
421 Value::Int(i) => {
422 let matrix = Tensor::new_2d(vec![i.to_f64()], 1, 1)?;
423 if cols == 0 {
424 cols = 1;
425 } else if cols != 1 {
426 return Err("Cannot concatenate scalar with multi-column matrix".to_string());
427 }
428 _total_rows += 1;
429 matrices.push(matrix);
430 }
431 Value::Tensor(m) => {
432 if m.rows() == 0 && m.cols() == 0 {
434 continue;
435 }
436 if cols == 0 {
437 cols = m.cols();
438 } else if cols != m.cols() {
439 return Err(format!(
440 "Cannot concatenate matrices with different column counts: {} vs {}",
441 cols,
442 m.cols()
443 ));
444 }
445 _total_rows += m.rows();
446 matrices.push(m.clone());
447 }
448 _ => return Err(format!("Cannot concatenate value of type {value:?}")),
449 }
450 }
451
452 let mut result = matrices[0].clone();
454 for matrix in &matrices[1..] {
455 result = vcat_matrices(&result, matrix)?;
456 }
457
458 Ok(Value::Tensor(result))
459}
460
461pub fn create_matrix_from_values(rows: &[Vec<Value>]) -> Result<Value, String> {
464 if rows.is_empty() {
465 return Ok(Value::Tensor(Tensor::new(vec![], vec![0, 0])?));
466 }
467
468 let mut all_cols_equal = true;
470 let cols = rows[0].len();
471 for r in rows {
472 if r.len() != cols {
473 all_cols_equal = false;
474 break;
475 }
476 }
477 if !all_cols_equal {
478 return Err("Matrix construction: inconsistent number of columns in rows".to_string());
479 }
480
481 let mut row_matrices: Vec<Value> = Vec::with_capacity(rows.len());
483 for row in rows {
484 let row_result = hcat_values(row)?;
485 row_matrices.push(row_result);
486 }
487
488 match row_matrices.len() {
490 0 => Ok(Value::Tensor(Tensor::new(vec![], vec![0, 0])?)),
491 1 => Ok(row_matrices.into_iter().next().unwrap()),
492 _ => vcat_values(&row_matrices),
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hcat_matrices() {
        // a = [1 3; 2 4] and b = [5; 6] (data is column-major).
        let a = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let b = Tensor::new_2d(vec![5.0, 6.0], 2, 1).unwrap();

        let result = hcat_matrices(&a, &b).unwrap();
        assert_eq!(result.rows(), 2);
        assert_eq!(result.cols(), 3);
        // [1 3 5; 2 4 6] column-major.
        assert_eq!(result.data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_vcat_matrices() {
        let a = Tensor::new_2d(vec![1.0, 2.0], 1, 2).unwrap();
        let b = Tensor::new_2d(vec![3.0, 4.0], 1, 2).unwrap();

        let result = vcat_matrices(&a, &b).unwrap();
        assert_eq!(result.rows(), 2);
        assert_eq!(result.cols(), 2);
        // [1 2; 3 4] stored column-major is [1, 3, 2, 4], matching the layout
        // that test_hcat_matrices relies on.
        assert_eq!(result.data, vec![1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn test_hcat_values_scalars() {
        let values = vec![Value::Num(1.0), Value::Num(2.0), Value::Num(3.0)];
        let result = hcat_values(&values).unwrap();

        match result {
            Value::Tensor(m) => {
                assert_eq!(m.rows(), 1);
                assert_eq!(m.cols(), 3);
                assert_eq!(m.data, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("Expected matrix result"),
        }
    }

    #[test]
    fn test_vcat_values_scalars() {
        let values = vec![Value::Num(1.0), Value::Num(2.0)];
        let result = vcat_values(&values).unwrap();

        match result {
            Value::Tensor(m) => {
                assert_eq!(m.rows(), 2);
                assert_eq!(m.cols(), 1);
                assert_eq!(m.data, vec![1.0, 2.0]);
            }
            _ => panic!("Expected matrix result"),
        }
    }
}