Skip to main content

runmat_runtime/builtins/strings/transform/
join.rs

1//! MATLAB-compatible `join` builtin with GPU-aware semantics for RunMat.
2
3use runmat_builtins::{CellArray, CharArray, StringArray, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::spec::{
8    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9    ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
12use crate::builtins::strings::type_resolvers::text_concat_type;
13use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
14
15#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::join")]
16pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
17    name: "join",
18    op_kind: GpuOpKind::Custom("string-transform"),
19    supported_precisions: &[],
20    broadcast: BroadcastSemantics::None,
21    provider_hooks: &[],
22    constant_strategy: ConstantStrategy::InlineLiteral,
23    residency: ResidencyPolicy::GatherImmediately,
24    nan_mode: ReductionNaN::Include,
25    two_pass_threshold: None,
26    workgroup_size: None,
27    accepts_nan_mode: false,
28    notes: "Executes on the host; GPU-resident inputs and delimiters are gathered before concatenation.",
29};
30
31#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::join")]
32pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
33    name: "join",
34    shape: ShapeRequirements::Any,
35    constant_strategy: ConstantStrategy::InlineLiteral,
36    elementwise: None,
37    reduction: None,
38    emits_nan: false,
39    notes: "Joins operate on CPU-managed text and are ineligible for fusion.",
40};
41
/// Builtin name attached to every error raised from this module.
const BUILTIN_NAME: &str = "join";

/// Raised when text (or an element of a text cell array) has an unsupported type.
const INPUT_TYPE_ERROR: &str =
    "join: input must be a string array, string scalar, character array, or cell array of character vectors";

/// Raised when the delimiter argument has an unsupported type.
const DELIMITER_TYPE_ERROR: &str =
    "join: delimiter must be a string, character vector, string array, or cell array of character vectors";

/// Raised when a delimiter array does not line up with the gaps of the joined text.
const DELIMITER_SIZE_ERROR: &str =
    "join: size of delimiter array must match the size of str, with the join dimension reduced by one";

/// Raised for a dimension argument that is not a positive integer.
const DIMENSION_TYPE_ERROR: &str = "join: dimension must be a positive integer scalar";
50
51fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
52    build_runtime_error(message)
53        .with_builtin(BUILTIN_NAME)
54        .build()
55}
56
57fn map_flow(err: RuntimeError) -> RuntimeError {
58    map_control_flow_with_builtin(err, BUILTIN_NAME)
59}
60
61#[runtime_builtin(
62    name = "join",
63    category = "strings/transform",
64    summary = "Combine text across a specified dimension inserting delimiters between elements.",
65    keywords = "join,string join,concatenate strings,delimiters,cell array join",
66    accel = "none",
67    type_resolver(text_concat_type),
68    builtin_path = "crate::builtins::strings::transform::join"
69)]
70async fn join_builtin(text: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
71    let text = gather_if_needed_async(&text).await.map_err(map_flow)?;
72    let mut args = Vec::with_capacity(rest.len());
73    for arg in rest {
74        args.push(gather_if_needed_async(&arg).await.map_err(map_flow)?);
75    }
76
77    let mut input = JoinInput::from_value(text)?;
78    let (delimiter_arg, dimension_arg) = parse_arguments(&args)?;
79
80    let mut shape = input.shape.clone();
81    if shape.is_empty() {
82        shape = vec![1, 1];
83    }
84
85    let default_dim = default_dimension(&shape);
86    let dimension = match dimension_arg {
87        Some(dim) => dim,
88        None => default_dim,
89    };
90
91    if dimension == 0 {
92        return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
93    }
94
95    let ndims = input.ndims();
96    if dimension > ndims {
97        return input.into_value();
98    }
99
100    let axis_idx = dimension - 1;
101    input.ensure_shape_len(dimension);
102    let full_shape = input.shape.clone();
103
104    let delimiter = Delimiter::from_value(delimiter_arg, &full_shape, axis_idx)?;
105
106    let (output_data, output_shape) = perform_join(&input.data, &full_shape, axis_idx, &delimiter);
107
108    input.build_output(output_data, output_shape)
109}
110
111fn parse_arguments(args: &[Value]) -> BuiltinResult<(Option<Value>, Option<usize>)> {
112    match args.len() {
113        0 => Ok((None, None)),
114        1 => {
115            if let Some(dim) = value_to_dimension(&args[0])? {
116                Ok((None, Some(dim)))
117            } else {
118                Ok((Some(args[0].clone()), None))
119            }
120        }
121        2 => {
122            if let Some(dim) = value_to_dimension(&args[1])? {
123                Ok((Some(args[0].clone()), Some(dim)))
124            } else if let Some(dim) = value_to_dimension(&args[0])? {
125                Ok((Some(args[1].clone()), Some(dim)))
126            } else {
127                Err(runtime_error_for(DIMENSION_TYPE_ERROR))
128            }
129        }
130        _ => Err(runtime_error_for("join: too many input arguments")),
131    }
132}
133
/// Returns the 1-based index of the last axis whose extent differs from 1, or `2`
/// when every axis is a singleton (including an empty `shape`).
fn default_dimension(shape: &[usize]) -> usize {
    shape
        .iter()
        .rposition(|&extent| extent != 1)
        .map_or(2, |index| index + 1)
}
142
143fn value_to_dimension(value: &Value) -> BuiltinResult<Option<usize>> {
144    match value {
145        Value::Int(i) => {
146            let v = i.to_i64();
147            if v <= 0 {
148                return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
149            }
150            Ok(Some(v as usize))
151        }
152        Value::Num(n) => {
153            if !n.is_finite() || *n <= 0.0 {
154                return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
155            }
156            let rounded = n.round();
157            if (rounded - n).abs() > f64::EPSILON {
158                return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
159            }
160            Ok(Some(rounded as usize))
161        }
162        Value::Tensor(t) if t.data.len() == 1 => {
163            let val = t.data[0];
164            if !val.is_finite() || val <= 0.0 {
165                return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
166            }
167            let rounded = val.round();
168            if (rounded - val).abs() > f64::EPSILON {
169                return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
170            }
171            Ok(Some(rounded as usize))
172        }
173        _ => Ok(None),
174    }
175}
176
177struct JoinInput {
178    data: Vec<String>,
179    shape: Vec<usize>,
180    kind: OutputKind,
181}
182
/// Container kind produced for the joined result.
#[derive(Clone)]
enum OutputKind {
    /// A single `Value::String`.
    StringScalar,
    /// A `Value::StringArray` (also chosen for character-array input).
    StringArray,
    /// A cell array whose elements are character row vectors.
    CellArray,
}
189
190impl JoinInput {
191    fn from_value(value: Value) -> BuiltinResult<Self> {
192        match value {
193            Value::String(text) => Ok(Self {
194                data: vec![text],
195                shape: vec![1, 1],
196                kind: OutputKind::StringScalar,
197            }),
198            Value::StringArray(array) => Ok(Self {
199                data: array.data,
200                shape: array.shape,
201                kind: OutputKind::StringArray,
202            }),
203            Value::CharArray(array) => {
204                let strings = char_array_rows_to_strings(&array);
205                Ok(Self {
206                    data: strings,
207                    shape: vec![array.rows, 1],
208                    kind: OutputKind::StringArray,
209                })
210            }
211            Value::Cell(cell) => {
212                let (data, shape) = cell_array_to_strings(cell)?;
213                Ok(Self {
214                    data,
215                    shape,
216                    kind: OutputKind::CellArray,
217                })
218            }
219            _ => Err(runtime_error_for(INPUT_TYPE_ERROR)),
220        }
221    }
222
223    fn ndims(&self) -> usize {
224        if self.shape.is_empty() {
225            2
226        } else {
227            self.shape.len().max(2)
228        }
229    }
230
231    fn ensure_shape_len(&mut self, dimension: usize) {
232        if self.shape.len() < dimension {
233            self.shape.resize(dimension, 1);
234        }
235    }
236
237    fn into_value(self) -> BuiltinResult<Value> {
238        build_value(self.kind, self.data, self.shape)
239    }
240
241    fn build_output(&self, data: Vec<String>, shape: Vec<usize>) -> BuiltinResult<Value> {
242        build_value(self.kind.clone(), data, shape)
243    }
244}
245
246fn build_value(kind: OutputKind, data: Vec<String>, shape: Vec<usize>) -> BuiltinResult<Value> {
247    match kind {
248        OutputKind::StringScalar => Ok(Value::String(data.into_iter().next().unwrap_or_default())),
249        OutputKind::StringArray => {
250            let array = StringArray::new(data, shape)
251                .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
252            Ok(Value::StringArray(array))
253        }
254        OutputKind::CellArray => {
255            let rows = shape.first().copied().unwrap_or(0);
256            let cols = shape.get(1).copied().unwrap_or(1);
257            if rows == 0 || cols == 0 || data.is_empty() {
258                return make_cell(Vec::new(), rows, cols)
259                    .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")));
260            }
261            let mut values = Vec::with_capacity(rows * cols);
262            for row in 0..rows {
263                for col in 0..cols {
264                    let idx = row + col * rows;
265                    let text = data[idx].clone();
266                    let chars: Vec<char> = text.chars().collect();
267                    let cols_count = chars.len();
268                    let char_array = CharArray::new(chars, 1, cols_count)
269                        .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
270                    values.push(Value::CharArray(char_array));
271                }
272            }
273            make_cell(values, rows, cols)
274                .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
275        }
276    }
277}
278
279fn char_array_rows_to_strings(array: &CharArray) -> Vec<String> {
280    let mut strings = Vec::with_capacity(array.rows);
281    for row in 0..array.rows {
282        strings.push(char_row_to_string_slice(&array.data, array.cols, row));
283    }
284    strings
285}
286
287fn cell_array_to_strings(cell: CellArray) -> BuiltinResult<(Vec<String>, Vec<usize>)> {
288    let CellArray {
289        data, rows, cols, ..
290    } = cell;
291    let mut strings = Vec::with_capacity(rows * cols);
292    for col in 0..cols {
293        for row in 0..rows {
294            let idx = row * cols + col;
295            strings.push(
296                cell_element_to_string(&data[idx])
297                    .ok_or_else(|| runtime_error_for(INPUT_TYPE_ERROR))?,
298            );
299        }
300    }
301    Ok((strings, vec![rows, cols]))
302}
303
304fn cell_element_to_string(value: &Value) -> Option<String> {
305    match value {
306        Value::String(text) => Some(text.clone()),
307        Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
308        Value::CharArray(array) if array.rows <= 1 => {
309            if array.rows == 0 {
310                Some(String::new())
311            } else {
312                Some(char_row_to_string_slice(&array.data, array.cols, 0))
313            }
314        }
315        _ => None,
316    }
317}
318
319#[derive(Clone)]
320enum Delimiter {
321    Scalar(String),
322    Array(DelimiterArray),
323}
324
/// Delimiter strings together with their validated shape and strides.
#[derive(Clone)]
struct DelimiterArray {
    /// Delimiter text, addressed through `strides`.
    data: Vec<String>,
    /// Extent of each axis after `normalize_delimiter_shape` (same rank as the input).
    shape: Vec<usize>,
    /// Offset step for each axis into `data`, as produced by `compute_strides`.
    strides: Vec<usize>,
}
331
332impl Delimiter {
333    fn from_value(
334        value: Option<Value>,
335        full_shape: &[usize],
336        axis_idx: usize,
337    ) -> BuiltinResult<Self> {
338        match value {
339            None => Ok(Self::Scalar(" ".to_string())),
340            Some(v) => {
341                if let Some(text) = value_to_scalar_string(&v) {
342                    return Ok(Self::Scalar(text));
343                }
344                let (data, shape) = value_to_string_array(v)?;
345                let normalized = normalize_delimiter_shape(shape, full_shape, axis_idx)?;
346                let strides = compute_strides(&normalized);
347                Ok(Self::Array(DelimiterArray {
348                    data,
349                    shape: normalized,
350                    strides,
351                }))
352            }
353        }
354    }
355
356    fn value<'a>(&'a self, coords: &[usize], axis_idx: usize, axis_gap: usize) -> &'a str {
357        match self {
358            Delimiter::Scalar(text) => text.as_str(),
359            Delimiter::Array(array) => array.value(coords, axis_idx, axis_gap),
360        }
361    }
362}
363
364impl DelimiterArray {
365    fn value<'a>(&'a self, coords: &[usize], axis_idx: usize, axis_gap: usize) -> &'a str {
366        let mut offset = 0usize;
367        for (dim, stride) in self.strides.iter().enumerate() {
368            let size = self.shape[dim];
369            let coord = if dim == axis_idx {
370                axis_gap.min(size.saturating_sub(1))
371            } else if size == 1 {
372                0
373            } else {
374                coords[dim].min(size.saturating_sub(1))
375            };
376            offset += coord * stride;
377        }
378        &self.data[offset]
379    }
380}
381
382fn value_to_scalar_string(value: &Value) -> Option<String> {
383    match value {
384        Value::String(text) => Some(text.clone()),
385        Value::CharArray(array) if array.rows <= 1 => {
386            if array.rows == 0 {
387                Some(String::new())
388            } else {
389                Some(char_row_to_string_slice(&array.data, array.cols, 0))
390            }
391        }
392        Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
393        Value::Cell(cell) if cell.data.len() == 1 => cell_element_to_string(&cell.data[0]),
394        _ => None,
395    }
396}
397
398fn value_to_string_array(value: Value) -> BuiltinResult<(Vec<String>, Vec<usize>)> {
399    match value {
400        Value::StringArray(array) => Ok((array.data, array.shape)),
401        Value::Cell(cell) => {
402            let (data, shape) = cell_array_to_strings(cell)?;
403            Ok((data, shape))
404        }
405        Value::CharArray(array) => {
406            let rows = array.rows;
407            let strings = char_array_rows_to_strings(&array);
408            Ok((strings, vec![rows, 1]))
409        }
410        _ => Err(runtime_error_for(DELIMITER_TYPE_ERROR)),
411    }
412}
413
414fn normalize_delimiter_shape(
415    mut shape: Vec<usize>,
416    full_shape: &[usize],
417    axis_idx: usize,
418) -> BuiltinResult<Vec<usize>> {
419    if shape.len() > full_shape.len() {
420        return Err(runtime_error_for(DELIMITER_SIZE_ERROR));
421    }
422    if shape.len() < full_shape.len() {
423        shape.resize(full_shape.len(), 1);
424    }
425
426    let axis_len = full_shape[axis_idx].saturating_sub(1);
427    if axis_len == 0 {
428        shape[axis_idx] = 1;
429    } else if shape[axis_idx] != axis_len {
430        return Err(runtime_error_for(DELIMITER_SIZE_ERROR));
431    }
432
433    for (dim, size) in shape.iter().enumerate() {
434        if dim == axis_idx {
435            continue;
436        }
437        let reference = full_shape[dim];
438        if *size != reference && *size != 1 {
439            return Err(runtime_error_for(DELIMITER_SIZE_ERROR));
440        }
441    }
442
443    Ok(shape)
444}
445
446fn perform_join(
447    data: &[String],
448    full_shape: &[usize],
449    axis_idx: usize,
450    delimiter: &Delimiter,
451) -> (Vec<String>, Vec<usize>) {
452    if full_shape.is_empty() {
453        return (vec![String::new()], vec![1, 1]);
454    }
455
456    let axis_len = full_shape[axis_idx];
457    let mut output_shape = full_shape.to_vec();
458
459    let rest_size = full_shape
460        .iter()
461        .enumerate()
462        .filter(|(idx, _)| *idx != axis_idx)
463        .fold(1usize, |acc, (_, size)| acc.saturating_mul(*size));
464
465    if rest_size == 0 {
466        output_shape[axis_idx] = 0;
467        return (Vec::new(), output_shape);
468    }
469
470    output_shape[axis_idx] = 1;
471
472    let total_output = rest_size;
473    let mut output = Vec::with_capacity(total_output);
474
475    let strides = compute_strides(full_shape);
476    let axis_stride = strides[axis_idx];
477    let dims = full_shape.len();
478    let mut coords = vec![0usize; dims];
479
480    for _ in 0..rest_size {
481        let mut base_offset = 0usize;
482        for dim in 0..dims {
483            base_offset += coords[dim] * strides[dim];
484        }
485
486        if axis_len == 0 {
487            output.push(String::new());
488        } else {
489            let mut result = String::new();
490            let mut missing = false;
491            for axis_pos in 0..axis_len {
492                let element_offset = base_offset + axis_pos * axis_stride;
493                let value = &data[element_offset];
494                if is_missing_string(value) {
495                    missing = true;
496                    break;
497                }
498                if axis_pos > 0 {
499                    let gap = axis_pos - 1;
500                    let delim = delimiter.value(&coords, axis_idx, gap);
501                    result.push_str(delim);
502                }
503                result.push_str(value);
504            }
505            if missing {
506                output.push("<missing>".to_string());
507            } else {
508                output.push(result);
509            }
510        }
511
512        increment_coords(&mut coords, full_shape, axis_idx);
513    }
514
515    (output, output_shape)
516}
517
/// Column-major strides for `shape`: the first axis has stride 1, and each later
/// axis steps by the (saturating) product of all earlier extents.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut running = 1usize;
    shape
        .iter()
        .map(|&extent| {
            let stride = running;
            running = running.saturating_mul(extent);
            stride
        })
        .collect()
}
525
/// Advances `coords` to the next position in first-axis-fastest order, never
/// touching the join axis `axis_idx`. After the final position, the non-axis
/// coordinates wrap back to zero.
fn increment_coords(coords: &mut [usize], shape: &[usize], axis_idx: usize) {
    for (dim, &extent) in shape.iter().enumerate().filter(|&(dim, _)| dim != axis_idx) {
        coords[dim] += 1;
        if coords[dim] < extent {
            return;
        }
        coords[dim] = 0;
    }
}
538
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider as wgpu_backend;
    use runmat_builtins::{IntValue, ResolveContext, Type};

    /// Synchronous wrapper so tests can drive the async builtin directly.
    fn join_builtin(text: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(super::join_builtin(text, rest))
    }

    /// Owned copies of `items`, used for inputs and expected outputs.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds a string array from borrowed text; panics if `shape` does not fit.
    fn string_array(items: &[&str], shape: &[usize]) -> StringArray {
        StringArray::new(owned(items), shape.to_vec()).unwrap()
    }

    /// Integer dimension argument.
    fn int(value: i32) -> Value {
        Value::Int(IntValue::I32(value))
    }

    /// Unwraps a `Value::StringArray`, panicking with the shared message otherwise.
    fn expect_string_array(value: Value) -> StringArray {
        match value {
            Value::StringArray(sa) => sa,
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_string_array_default_dimension() {
        let names = string_array(&["Carlos", "Ella", "Diana", "Sada", "Olsen", "Lee"], &[3, 2]);
        let sa = expect_string_array(
            join_builtin(Value::StringArray(names), Vec::new()).expect("join"),
        );
        assert_eq!(sa.shape, vec![3, 1]);
        assert_eq!(sa.data, owned(&["Carlos Sada", "Ella Olsen", "Diana Lee"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_custom_scalar_delimiter() {
        let letters = string_array(&["x", "a", "y", "b", "z", "c"], &[2, 3]);
        let sa = expect_string_array(
            join_builtin(Value::StringArray(letters), vec![Value::String("-".into())])
                .expect("join"),
        );
        assert_eq!(sa.shape, vec![2, 1]);
        assert_eq!(sa.data, owned(&["x-y-z", "a-b-c"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_delimiter_array_per_row() {
        let letters = string_array(&["x", "a", "y", "b", "z", "c"], &[2, 3]);
        let delims = string_array(&[" + ", " - ", " = ", " = "], &[2, 2]);
        let sa = expect_string_array(
            join_builtin(Value::StringArray(letters), vec![Value::StringArray(delims)])
                .expect("join"),
        );
        assert_eq!(sa.shape, vec![2, 1]);
        assert_eq!(sa.data, owned(&["x + y = z", "a - b = c"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_dimension_argument() {
        let names = string_array(&["Carlos", "Ella", "Diana", "Sada", "Olsen", "Lee"], &[3, 2]);
        let sa = expect_string_array(
            join_builtin(Value::StringArray(names), vec![int(1)]).expect("join"),
        );
        assert_eq!(sa.shape, vec![1, 2]);
        assert_eq!(sa.data, owned(&["Carlos Ella Diana", "Sada Olsen Lee"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_dimension_greater_than_ndims_returns_input() {
        let array = string_array(&["a", "b"], &[1, 2]);
        let result = join_builtin(Value::StringArray(array.clone()), vec![int(4)]).expect("join");
        let sa = match result {
            Value::StringArray(sa) => sa,
            other => panic!("expected original array, got {other:?}"),
        };
        assert_eq!(sa.shape, array.shape);
        assert_eq!(sa.data, array.data);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_cell_array_of_char_vectors() {
        let values: Vec<Value> = ["GPU", "Accelerate", "Ignition", "Interpreter"]
            .iter()
            .map(|&text| Value::CharArray(CharArray::new_row(text)))
            .collect();
        let cell = make_cell(values, 2, 2).expect("cell");
        let result = join_builtin(cell, vec![Value::String(", ".into())]).expect("join cell");
        let cell_out = match result {
            Value::Cell(cell_out) => cell_out,
            other => panic!("expected cell array, got {other:?}"),
        };
        assert_eq!(cell_out.rows, 2);
        assert_eq!(cell_out.cols, 1);
        // SAFETY (assumed): `cell_out` owns these element handles and stays alive for
        // the rest of the test, so the dereferenced pointers remain valid.
        let (first, second) =
            unsafe { (&*cell_out.data[0].as_raw(), &*cell_out.data[1].as_raw()) };
        match (first, second) {
            (Value::CharArray(a), Value::CharArray(b)) => {
                assert_eq!(
                    char_row_to_string_slice(&a.data, a.cols, 0),
                    "GPU, Accelerate"
                );
                assert_eq!(
                    char_row_to_string_slice(&b.data, b.cols, 0),
                    "Ignition, Interpreter"
                );
            }
            other => panic!("expected char arrays, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_numeric_second_argument_uses_default_delimiter() {
        let words = string_array(&["RunMat", "Accelerate", "Planner"], &[3, 1]);
        let sa = expect_string_array(
            join_builtin(Value::StringArray(words), vec![int(1)]).expect("join"),
        );
        assert_eq!(sa.shape, vec![1, 1]);
        assert_eq!(sa.data, owned(&["RunMat Accelerate Planner"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_char_array_input_produces_string_array() {
        let chars: Vec<char> = "RunMatGPUDev".chars().collect();
        let char_array = CharArray::new(chars, 3, 4).unwrap();
        let sa = expect_string_array(
            join_builtin(Value::CharArray(char_array), Vec::new()).expect("join"),
        );
        assert_eq!(sa.shape, vec![1, 1]);
        assert_eq!(sa.data, owned(&["RunM atGP UDev"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_cell_delimiter_array() {
        let backends = string_array(&["g", "c", "w", "gpu", "cuda", "wgpu"], &[3, 2]);
        let delimiters = make_cell(
            [" -> ", " => ", " :: "]
                .iter()
                .map(|&text| Value::String(String::from(text)))
                .collect(),
            3,
            1,
        )
        .expect("cell");
        let sa = expect_string_array(
            join_builtin(Value::StringArray(backends), vec![delimiters, int(2)]).expect("join"),
        );
        assert_eq!(sa.shape, vec![3, 1]);
        assert_eq!(sa.data, owned(&["g -> gpu", "c => cuda", "w :: wgpu"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_3d_string_array_along_third_dimension() {
        // Pages outermost, then columns, then rows: first-axis-fastest storage.
        let data: Vec<String> = (0..2)
            .flat_map(|page| {
                (0..2).flat_map(move |col| {
                    (0..2).map(move |row| format!("r{row}c{col}p{page}"))
                })
            })
            .collect();
        let cube = StringArray::new(data, vec![2, 2, 2]).unwrap();
        let sa = expect_string_array(
            join_builtin(
                Value::StringArray(cube),
                vec![Value::String(":".into()), int(3)],
            )
            .expect("join"),
        );
        assert_eq!(sa.shape, vec![2, 2, 1]);
        assert_eq!(
            sa.data,
            owned(&[
                "r0c0p0:r0c0p1",
                "r1c0p0:r1c0p1",
                "r0c1p0:r0c1p1",
                "r1c1p0:r1c1p1",
            ])
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_errors_on_zero_dimension() {
        let single = string_array(&["a"], &[1, 1]);
        let err = join_builtin(Value::StringArray(single), vec![int(0)]).unwrap_err();
        let err_text = err.to_string();
        assert!(
            err_text.contains("dimension"),
            "expected dimension error, got {err_text}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_errors_on_mismatched_delimiter_shape() {
        let letters = string_array(&["a", "b", "c"], &[1, 3]);
        let delims = string_array(&["+", "-", "="], &[1, 3]);
        let outcome = join_builtin(Value::StringArray(letters), vec![Value::StringArray(delims)]);
        assert!(outcome.is_err());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_propagates_missing_strings() {
        let with_missing = string_array(&["GPU", "<missing>"], &[1, 2]);
        let sa = expect_string_array(
            join_builtin(Value::StringArray(with_missing), Vec::new()).expect("join"),
        );
        assert_eq!(sa.data, owned(&["<missing>"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_accepts_char_delimiter_scalar() {
        let pair = string_array(&["A", "B"], &[1, 2]);
        let plus_plus = CharArray::new("++".chars().collect::<Vec<char>>(), 1, 2).unwrap();
        let sa = expect_string_array(
            join_builtin(Value::StringArray(pair), vec![Value::CharArray(plus_plus)])
                .expect("join"),
        );
        assert_eq!(sa.data, owned(&["A++B"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_handles_empty_axis() {
        let empty = string_array(&[], &[2, 0]);
        let sa = expect_string_array(
            join_builtin(Value::StringArray(empty), Vec::new()).expect("join"),
        );
        assert_eq!(sa.shape, vec![2, 1]);
        assert_eq!(sa.data, owned(&["", ""]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_missing_dimension_broadcast_delimiters() {
        let grid = string_array(&["aa", "cc", "bb", "dd"], &[2, 2]);
        let dash = string_array(&["-"], &[1, 1]);
        let sa = expect_string_array(
            join_builtin(
                Value::StringArray(grid),
                vec![Value::StringArray(dash), int(2)],
            )
            .expect("join"),
        );
        assert_eq!(sa.shape, vec![2, 1]);
        assert_eq!(sa.data, owned(&["aa-bb", "cc-dd"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn join_executes_with_wgpu_provider_registered() {
        let _ = wgpu_backend::register_wgpu_provider(wgpu_backend::WgpuProviderOptions::default());
        let column = string_array(&["GPU", "Planner"], &[2, 1]);
        let sa = expect_string_array(
            join_builtin(Value::StringArray(column), Vec::new()).expect("join"),
        );
        assert_eq!(sa.data, owned(&["GPU Planner"]));
    }

    #[test]
    fn join_type_concatenates_text() {
        let context = ResolveContext::new(Vec::new());
        assert_eq!(text_concat_type(&[Type::String], &context), Type::String);
    }
}