runmat_runtime/
concatenation.rs

1//! Matrix and array concatenation operations
2//!
3//! This module provides language-compatible matrix concatenation operations.
4//! Supports both horizontal concatenation [A, B] and vertical concatenation [A; B].
5
6use runmat_builtins::{Tensor, Value};
7
8/// Horizontally concatenate two matrices [A, B]
9/// In language: C = [A, B] creates a matrix with A and B side by side
10pub fn hcat_matrices(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
11    // Language semantics: [] acts as a neutral element for concatenation
12    if a.rows() == 0 && a.cols() == 0 {
13        return Ok(b.clone());
14    }
15    if b.rows() == 0 && b.cols() == 0 {
16        return Ok(a.clone());
17    }
18    if a.rows() != b.rows() {
19        return Err(format!(
20            "Cannot horizontally concatenate matrices with different row counts: {} vs {}",
21            a.rows, b.rows
22        ));
23    }
24
25    let new_rows = a.rows();
26    let new_cols = a.cols() + b.cols();
27    let mut new_data = Vec::with_capacity(new_rows * new_cols);
28
29    // Column-major layout: build column-by-column
30    for col in 0..new_cols {
31        if col < a.cols() {
32            for row in 0..a.rows() {
33                new_data.push(a.data[row + col * a.rows()]);
34            }
35        } else {
36            let bcol = col - a.cols();
37            for row in 0..b.rows() {
38                new_data.push(b.data[row + bcol * b.rows()]);
39            }
40        }
41    }
42
43    Tensor::new_2d(new_data, new_rows, new_cols)
44}
45
46/// Vertically concatenate two matrices [A; B]
47/// In language: C = [A; B] creates a matrix with A on top and B below
48pub fn vcat_matrices(a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
49    // Language semantics: [] acts as a neutral element for concatenation
50    if a.rows() == 0 && a.cols() == 0 {
51        return Ok(b.clone());
52    }
53    if b.rows() == 0 && b.cols() == 0 {
54        return Ok(a.clone());
55    }
56    if a.cols() != b.cols() {
57        return Err(format!(
58            "Cannot vertically concatenate matrices with different column counts: {} vs {}",
59            a.cols, b.cols
60        ));
61    }
62
63    let new_rows = a.rows() + b.rows();
64    let new_cols = a.cols();
65    let mut new_data = Vec::with_capacity(new_rows * new_cols);
66
67    // Column-major: copy columns of A then columns of B
68    for col in 0..a.cols() {
69        for row in 0..a.rows() {
70            new_data.push(a.data[row + col * a.rows()]);
71        }
72    }
73    for col in 0..b.cols() {
74        for row in 0..b.rows() {
75            new_data.push(b.data[row + col * b.rows()]);
76        }
77    }
78
79    Tensor::new_2d(new_data, new_rows, new_cols)
80}
81
82/// Concatenate values horizontally - handles mixed scalars and matrices
83pub fn hcat_values(values: &[Value]) -> Result<Value, String> {
84    if values.is_empty() {
85        return Ok(Value::Tensor(Tensor::new(vec![], vec![0, 0])?));
86    }
87
88    // If any operand is a string or string array, perform string-array concatenation
89    let has_str = values.iter().any(|v| {
90        matches!(
91            v,
92            Value::String(_) | Value::StringArray(_) | Value::CharArray(_)
93        )
94    });
95    if has_str {
96        // Normalize all to string-arrays, then horizontal concat by columns
97        // Determine row count: if any is string array, its rows; if string scalar or numeric scalar, rows=1
98        let mut rows: Option<usize> = None;
99        let mut cols_total = 0usize;
100        let mut blocks: Vec<runmat_builtins::StringArray> = Vec::new();
101        for v in values {
102            match v {
103                Value::StringArray(sa) => {
104                    if rows.is_none() {
105                        rows = Some(sa.rows());
106                    } else if rows != Some(sa.rows()) {
107                        return Err("string hcat: row mismatch".to_string());
108                    }
109                    cols_total += sa.cols();
110                    blocks.push(sa.clone());
111                }
112                Value::String(s) => {
113                    let sa =
114                        runmat_builtins::StringArray::new(vec![s.clone()], vec![1, 1]).unwrap();
115                    if rows.is_none() {
116                        rows = Some(1);
117                    } else if rows != Some(1) {
118                        return Err("string hcat: row mismatch".to_string());
119                    }
120                    cols_total += 1;
121                    blocks.push(sa);
122                }
123                Value::CharArray(ca) => {
124                    // Convert char array to string array by rows
125                    if ca.rows == 0 {
126                        continue;
127                    }
128                    if rows.is_none() {
129                        rows = Some(ca.rows);
130                    } else if rows != Some(ca.rows) {
131                        return Err("string hcat: row mismatch".to_string());
132                    }
133                    let mut out: Vec<String> = Vec::with_capacity(ca.rows);
134                    for r in 0..ca.rows {
135                        let mut s = String::with_capacity(ca.cols);
136                        for c in 0..ca.cols {
137                            s.push(ca.data[r * ca.cols + c]);
138                        }
139                        out.push(s);
140                    }
141                    let sa = runmat_builtins::StringArray::new(out, vec![ca.rows, 1]).unwrap();
142                    cols_total += 1;
143                    blocks.push(sa);
144                }
145                Value::Num(n) => {
146                    let sa =
147                        runmat_builtins::StringArray::new(vec![n.to_string()], vec![1, 1]).unwrap();
148                    if rows.is_none() {
149                        rows = Some(1);
150                    } else if rows != Some(1) {
151                        return Err("string hcat: row mismatch".to_string());
152                    }
153                    cols_total += 1;
154                    blocks.push(sa);
155                }
156                Value::Complex(re, im) => {
157                    let sa = runmat_builtins::StringArray::new(
158                        vec![runmat_builtins::Value::Complex(*re, *im).to_string()],
159                        vec![1, 1],
160                    )
161                    .unwrap();
162                    if rows.is_none() {
163                        rows = Some(1);
164                    } else if rows != Some(1) {
165                        return Err("string hcat: row mismatch".to_string());
166                    }
167                    cols_total += 1;
168                    blocks.push(sa);
169                }
170                Value::Int(i) => {
171                    let sa =
172                        runmat_builtins::StringArray::new(vec![i.to_i64().to_string()], vec![1, 1])
173                            .unwrap();
174                    if rows.is_none() {
175                        rows = Some(1);
176                    } else if rows != Some(1) {
177                        return Err("string hcat: row mismatch".to_string());
178                    }
179                    cols_total += 1;
180                    blocks.push(sa);
181                }
182                Value::Tensor(_) | Value::Cell(_) => {
183                    return Err(format!(
184                        "Cannot concatenate value of type {v:?} with string array"
185                    ))
186                }
187                _ => {
188                    return Err(format!(
189                        "Cannot concatenate value of type {v:?} with string array"
190                    ))
191                }
192            }
193        }
194        let rows = rows.unwrap_or(0);
195        let mut data: Vec<String> = Vec::with_capacity(rows * cols_total);
196        for cacc in 0..cols_total {
197            let _ = cacc;
198        }
199        // Stitch columns block-by-block in column-major
200        for block in &blocks {
201            for c in 0..block.cols() {
202                for r in 0..rows {
203                    let idx = r + c * rows;
204                    data.push(block.data[idx].clone());
205                }
206            }
207        }
208        let sa = runmat_builtins::StringArray::new(data, vec![rows, cols_total])
209            .map_err(|e| format!("string hcat: {e}"))?;
210        return Ok(Value::StringArray(sa));
211    }
212
213    // Convert all scalars to 1x1 matrices for uniform processing
214    let mut matrices = Vec::new();
215    let mut _total_cols = 0;
216    let mut rows = 0;
217
218    for value in values {
219        match value {
220            Value::Num(n) => {
221                let matrix = Tensor::new_2d(vec![*n], 1, 1)?;
222                if rows == 0 {
223                    rows = 1;
224                } else if rows != 1 {
225                    return Err("Cannot concatenate scalar with multi-row matrix".to_string());
226                }
227                _total_cols += 1;
228                matrices.push(matrix);
229            }
230            Value::Complex(re, _im) => {
231                let matrix = Tensor::new_2d(vec![*re], 1, 1)?; // real part in numeric hcat coercion
232                if rows == 0 {
233                    rows = 1;
234                } else if rows != 1 {
235                    return Err("Cannot concatenate scalar with multi-row matrix".to_string());
236                }
237                _total_cols += 1;
238                matrices.push(matrix);
239            }
240            Value::Int(i) => {
241                let matrix = Tensor::new_2d(vec![i.to_f64()], 1, 1)?;
242                if rows == 0 {
243                    rows = 1;
244                } else if rows != 1 {
245                    return Err("Cannot concatenate scalar with multi-row matrix".to_string());
246                }
247                _total_cols += 1;
248                matrices.push(matrix);
249            }
250            Value::Tensor(m) => {
251                // Skip true empty 0x0 operands (neutral element)
252                if m.rows() == 0 && m.cols() == 0 {
253                    continue;
254                }
255                if rows == 0 {
256                    rows = m.rows();
257                } else if rows != m.rows() {
258                    return Err(format!(
259                        "Cannot concatenate matrices with different row counts: {} vs {}",
260                        rows,
261                        m.rows()
262                    ));
263                }
264                _total_cols += m.cols();
265                matrices.push(m.clone());
266            }
267            _ => return Err(format!("Cannot concatenate value of type {value:?}")),
268        }
269    }
270
271    // Now concatenate all matrices horizontally
272    let mut result = matrices[0].clone();
273    for matrix in &matrices[1..] {
274        result = hcat_matrices(&result, matrix)?;
275    }
276
277    Ok(Value::Tensor(result))
278}
279
280/// Concatenate values vertically - handles mixed scalars and matrices
281pub fn vcat_values(values: &[Value]) -> Result<Value, String> {
282    if values.is_empty() {
283        return Ok(Value::Tensor(Tensor::new(vec![], vec![0, 0])?));
284    }
285
286    // If any operand is a string or string array, perform string-array vertical concatenation by stacking rows
287    let has_str = values.iter().any(|v| {
288        matches!(
289            v,
290            Value::String(_) | Value::StringArray(_) | Value::CharArray(_)
291        )
292    });
293    if has_str {
294        // Normalize to string-arrays; for scalars, treat as 1x1
295        let mut cols: Option<usize> = None;
296        let mut rows_total = 0usize;
297        let mut blocks: Vec<runmat_builtins::StringArray> = Vec::new();
298        for v in values {
299            match v {
300                Value::StringArray(sa) => {
301                    if cols.is_none() {
302                        cols = Some(sa.cols());
303                    } else if cols != Some(sa.cols()) {
304                        return Err("string vcat: column mismatch".to_string());
305                    }
306                    rows_total += sa.rows();
307                    blocks.push(sa.clone());
308                }
309                Value::String(s) => {
310                    let sa =
311                        runmat_builtins::StringArray::new(vec![s.clone()], vec![1, 1]).unwrap();
312                    rows_total += 1;
313                    if cols.is_none() {
314                        cols = Some(1);
315                    } else if cols != Some(1) {
316                        return Err("string vcat: column mismatch".to_string());
317                    }
318                    blocks.push(sa);
319                }
320                Value::CharArray(ca) => {
321                    if ca.cols == 0 {
322                        continue;
323                    }
324                    let out: String = ca.data.iter().collect();
325                    let sa = runmat_builtins::StringArray::new(vec![out], vec![1, 1]).unwrap();
326                    rows_total += 1;
327                    if cols.is_none() {
328                        cols = Some(1);
329                    } else if cols != Some(1) {
330                        return Err("string vcat: column mismatch".to_string());
331                    }
332                    blocks.push(sa);
333                }
334                Value::Num(n) => {
335                    let sa =
336                        runmat_builtins::StringArray::new(vec![n.to_string()], vec![1, 1]).unwrap();
337                    rows_total += 1;
338                    if cols.is_none() {
339                        cols = Some(1);
340                    } else if cols != Some(1) {
341                        return Err("string vcat: column mismatch".to_string());
342                    }
343                    blocks.push(sa);
344                }
345                Value::Complex(re, im) => {
346                    let sa = runmat_builtins::StringArray::new(
347                        vec![runmat_builtins::Value::Complex(*re, *im).to_string()],
348                        vec![1, 1],
349                    )
350                    .unwrap();
351                    rows_total += 1;
352                    if cols.is_none() {
353                        cols = Some(1);
354                    } else if cols != Some(1) {
355                        return Err("string vcat: column mismatch".to_string());
356                    }
357                    blocks.push(sa);
358                }
359                Value::Int(i) => {
360                    let sa =
361                        runmat_builtins::StringArray::new(vec![i.to_i64().to_string()], vec![1, 1])
362                            .unwrap();
363                    rows_total += 1;
364                    if cols.is_none() {
365                        cols = Some(1);
366                    } else if cols != Some(1) {
367                        return Err("string vcat: column mismatch".to_string());
368                    }
369                    blocks.push(sa);
370                }
371                _ => {
372                    return Err(format!(
373                        "Cannot concatenate value of type {v:?} with string array"
374                    ))
375                }
376            }
377        }
378        let cols = cols.unwrap_or(0);
379        let mut data: Vec<String> = Vec::with_capacity(rows_total * cols);
380        // Stack rows: copy columns for each block into data
381        for block in &blocks {
382            for c in 0..cols {
383                for r in 0..block.rows() {
384                    let idx = r + c * block.rows();
385                    data.push(block.data[idx].clone());
386                }
387            }
388        }
389        let sa = runmat_builtins::StringArray::new(data, vec![rows_total, cols])
390            .map_err(|e| format!("string vcat: {e}"))?;
391        return Ok(Value::StringArray(sa));
392    }
393
394    // Convert all scalars to 1x1 matrices for uniform processing
395    let mut matrices = Vec::new();
396    let mut _total_rows = 0;
397    let mut cols = 0;
398
399    for value in values {
400        match value {
401            Value::Num(n) => {
402                let matrix = Tensor::new_2d(vec![*n], 1, 1)?;
403                if cols == 0 {
404                    cols = 1;
405                } else if cols != 1 {
406                    return Err("Cannot concatenate scalar with multi-column matrix".to_string());
407                }
408                _total_rows += 1;
409                matrices.push(matrix);
410            }
411            Value::Complex(re, _im) => {
412                let matrix = Tensor::new_2d(vec![*re], 1, 1)?;
413                if cols == 0 {
414                    cols = 1;
415                } else if cols != 1 {
416                    return Err("Cannot concatenate scalar with multi-column matrix".to_string());
417                }
418                _total_rows += 1;
419                matrices.push(matrix);
420            }
421            Value::Int(i) => {
422                let matrix = Tensor::new_2d(vec![i.to_f64()], 1, 1)?;
423                if cols == 0 {
424                    cols = 1;
425                } else if cols != 1 {
426                    return Err("Cannot concatenate scalar with multi-column matrix".to_string());
427                }
428                _total_rows += 1;
429                matrices.push(matrix);
430            }
431            Value::Tensor(m) => {
432                // Skip true empty 0x0 operands (neutral element)
433                if m.rows() == 0 && m.cols() == 0 {
434                    continue;
435                }
436                if cols == 0 {
437                    cols = m.cols();
438                } else if cols != m.cols() {
439                    return Err(format!(
440                        "Cannot concatenate matrices with different column counts: {} vs {}",
441                        cols,
442                        m.cols()
443                    ));
444                }
445                _total_rows += m.rows();
446                matrices.push(m.clone());
447            }
448            _ => return Err(format!("Cannot concatenate value of type {value:?}")),
449        }
450    }
451
452    // Now concatenate all matrices vertically
453    let mut result = matrices[0].clone();
454    for matrix in &matrices[1..] {
455        result = vcat_matrices(&result, matrix)?;
456    }
457
458    Ok(Value::Tensor(result))
459}
460
461/// Create a matrix from a 2D array of Values with proper concatenation semantics
462/// This handles the case where matrix elements can be variables, not just literals
463pub fn create_matrix_from_values(rows: &[Vec<Value>]) -> Result<Value, String> {
464    if rows.is_empty() {
465        return Ok(Value::Tensor(Tensor::new(vec![], vec![0, 0])?));
466    }
467
468    // Validate rectangularity
469    let mut all_cols_equal = true;
470    let cols = rows[0].len();
471    for r in rows {
472        if r.len() != cols {
473            all_cols_equal = false;
474            break;
475        }
476    }
477    if !all_cols_equal {
478        return Err("Matrix construction: inconsistent number of columns in rows".to_string());
479    }
480
481    // First, concatenate each row horizontally
482    let mut row_matrices: Vec<Value> = Vec::with_capacity(rows.len());
483    for row in rows {
484        let row_result = hcat_values(row)?;
485        row_matrices.push(row_result);
486    }
487
488    // Then concatenate all rows vertically
489    match row_matrices.len() {
490        0 => Ok(Value::Tensor(Tensor::new(vec![], vec![0, 0])?)),
491        1 => Ok(row_matrices.into_iter().next().unwrap()),
492        _ => vcat_values(&row_matrices),
493    }
494}
495
496#[cfg(test)]
497mod tests {
498    use super::*;
499
500    #[test]
501    fn test_hcat_matrices() {
502        let a = Tensor::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
503        let b = Tensor::new_2d(vec![5.0, 6.0], 2, 1).unwrap();
504
505        let result = hcat_matrices(&a, &b).unwrap();
506        assert_eq!(result.rows(), 2);
507        assert_eq!(result.cols(), 3);
508        // Column-major result: [ [1 3 5]; [2 4 6] ] data
509        assert_eq!(result.data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
510    }
511
512    #[test]
513    fn test_vcat_matrices() {
514        let a = Tensor::new_2d(vec![1.0, 2.0], 1, 2).unwrap();
515        let b = Tensor::new_2d(vec![3.0, 4.0], 1, 2).unwrap();
516
517        let result = vcat_matrices(&a, &b).unwrap();
518        assert_eq!(result.rows(), 2);
519        assert_eq!(result.cols(), 2);
520        // Column-major: columns preserved
521        // With our current vcat implementation, data appends column-wise preserving row order within each input
522        // For 1x2 stacked over 1x2, result data is [1,2,3,4]
523        assert_eq!(result.data, vec![1.0, 2.0, 3.0, 4.0]);
524    }
525
526    #[test]
527    fn test_hcat_values_scalars() {
528        let values = vec![Value::Num(1.0), Value::Num(2.0), Value::Num(3.0)];
529        let result = hcat_values(&values).unwrap();
530
531        if let Value::Tensor(m) = result {
532            assert_eq!(m.rows(), 1);
533            assert_eq!(m.cols(), 3);
534            // Column-major: 1x3 row vector still row-major visually, data order follows cols
535            assert_eq!(m.data, vec![1.0, 2.0, 3.0]);
536        } else {
537            panic!("Expected matrix result");
538        }
539    }
540
541    #[test]
542    fn test_vcat_values_scalars() {
543        let values = vec![Value::Num(1.0), Value::Num(2.0)];
544        let result = vcat_values(&values).unwrap();
545
546        if let Value::Tensor(m) = result {
547            assert_eq!(m.rows(), 2);
548            assert_eq!(m.cols(), 1);
549            assert_eq!(m.data, vec![1.0, 2.0]);
550        } else {
551            panic!("Expected matrix result");
552        }
553    }
554}