Skip to main content

runmat_runtime/builtins/strings/transform/
join.rs

1//! MATLAB-compatible `join` builtin with GPU-aware semantics for RunMat.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6};
7use runmat_builtins::{CellArray, CharArray, StringArray, Value};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
16use crate::builtins::strings::type_resolvers::text_concat_type;
17use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
18
// GPU dispatch metadata registered for the `join` builtin.
//
// The spec declares no provider hooks and no supported precisions, and its
// residency policy is `GatherImmediately`; as the `notes` field states, the
// builtin executes on the host after GPU-resident inputs and delimiters have
// been gathered.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::join")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "join",
    op_kind: GpuOpKind::Custom("string-transform"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Executes on the host; GPU-resident inputs and delimiters are gathered before concatenation.",
};
34
// Fusion metadata registered for the `join` builtin.
//
// Neither an elementwise nor a reduction kernel is provided (`None` for both),
// and the `notes` field records that joins are ineligible for fusion because
// they operate on CPU-managed text.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::join")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "join",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Joins operate on CPU-managed text and are ineligible for fusion.",
};
45
/// Builtin name attached to runtime errors, used when remapping control-flow
/// errors, and used as the prefix of internal error messages in this module.
const BUILTIN_NAME: &str = "join";
47
/// Output descriptor shared by every `join` signature: a single required
/// `out` value of unconstrained (`Any`) type.
const JOIN_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "out",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Joined text preserving join output container semantics.",
}];
55
56const JOIN_INPUTS_BASE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
57    name: "str",
58    ty: BuiltinParamType::Any,
59    arity: BuiltinParamArity::Required,
60    default: None,
61    description: "Input text (string/char/cell).",
62}];
63
64const JOIN_INPUTS_DELIMITER: [BuiltinParamDescriptor; 2] = [
65    BuiltinParamDescriptor {
66        name: "str",
67        ty: BuiltinParamType::Any,
68        arity: BuiltinParamArity::Required,
69        default: None,
70        description: "Input text (string/char/cell).",
71    },
72    BuiltinParamDescriptor {
73        name: "delimiter",
74        ty: BuiltinParamType::Any,
75        arity: BuiltinParamArity::Required,
76        default: Some("\" \""),
77        description: "Delimiter scalar or delimiter array matching join shape constraints.",
78    },
79];
80
81const JOIN_INPUTS_DIM: [BuiltinParamDescriptor; 2] = [
82    BuiltinParamDescriptor {
83        name: "str",
84        ty: BuiltinParamType::Any,
85        arity: BuiltinParamArity::Required,
86        default: None,
87        description: "Input text (string/char/cell).",
88    },
89    BuiltinParamDescriptor {
90        name: "dim",
91        ty: BuiltinParamType::IntegerScalar,
92        arity: BuiltinParamArity::Required,
93        default: None,
94        description: "Positive dimension index to join along.",
95    },
96];
97
98const JOIN_INPUTS_DELIMITER_DIM: [BuiltinParamDescriptor; 3] = [
99    BuiltinParamDescriptor {
100        name: "str",
101        ty: BuiltinParamType::Any,
102        arity: BuiltinParamArity::Required,
103        default: None,
104        description: "Input text (string/char/cell).",
105    },
106    BuiltinParamDescriptor {
107        name: "delimiter",
108        ty: BuiltinParamType::Any,
109        arity: BuiltinParamArity::Required,
110        default: None,
111        description: "Delimiter scalar or delimiter array matching join shape constraints.",
112    },
113    BuiltinParamDescriptor {
114        name: "dim",
115        ty: BuiltinParamType::IntegerScalar,
116        arity: BuiltinParamArity::Required,
117        default: None,
118        description: "Positive dimension index to join along.",
119    },
120];
121
122const JOIN_INPUTS_DIM_DELIMITER: [BuiltinParamDescriptor; 3] = [
123    BuiltinParamDescriptor {
124        name: "str",
125        ty: BuiltinParamType::Any,
126        arity: BuiltinParamArity::Required,
127        default: None,
128        description: "Input text (string/char/cell).",
129    },
130    BuiltinParamDescriptor {
131        name: "dim",
132        ty: BuiltinParamType::IntegerScalar,
133        arity: BuiltinParamArity::Required,
134        default: None,
135        description: "Positive dimension index to join along.",
136    },
137    BuiltinParamDescriptor {
138        name: "delimiter",
139        ty: BuiltinParamType::Any,
140        arity: BuiltinParamArity::Required,
141        default: None,
142        description: "Delimiter scalar or delimiter array matching join shape constraints.",
143    },
144];
145
/// The five call forms advertised for `join`. Each entry pairs a display label
/// with its input descriptor table; all of them share `JOIN_OUTPUT`.
const JOIN_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
    BuiltinSignatureDescriptor {
        label: "out = join(str)",
        inputs: &JOIN_INPUTS_BASE,
        outputs: &JOIN_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "out = join(str, delimiter)",
        inputs: &JOIN_INPUTS_DELIMITER,
        outputs: &JOIN_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "out = join(str, dim)",
        inputs: &JOIN_INPUTS_DIM,
        outputs: &JOIN_OUTPUT,
    },
    // The delimiter and dimension may be supplied in either order.
    BuiltinSignatureDescriptor {
        label: "out = join(str, delimiter, dim)",
        inputs: &JOIN_INPUTS_DELIMITER_DIM,
        outputs: &JOIN_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "out = join(str, dim, delimiter)",
        inputs: &JOIN_INPUTS_DIM_DELIMITER,
        outputs: &JOIN_OUTPUT,
    },
];
173
/// Raised when the text operand (or an element of a text cell array) cannot be
/// interpreted as text; see `JoinInput::from_value` and `cell_array_to_strings`.
const JOIN_ERROR_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.INPUT_TYPE",
    identifier: Some("RunMat:join:InputType"),
    when: "Input text is not a string array/scalar, char array, or cell array of text scalars.",
    message:
        "join: input must be a string array, string scalar, character array, or cell array of character vectors",
};

/// Raised by `value_to_string_array` when the delimiter is not a supported
/// text container.
const JOIN_ERROR_DELIMITER_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.DELIMITER_TYPE",
    identifier: Some("RunMat:join:DelimiterType"),
    when: "Delimiter is not a supported text scalar/array/cell value.",
    message:
        "join: delimiter must be a string, character vector, string array, or cell array of character vectors",
};

/// Raised by `normalize_delimiter_shape` when a delimiter array does not fit
/// the shape of the text operand.
const JOIN_ERROR_DELIMITER_SIZE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.DELIMITER_SIZE",
    identifier: Some("RunMat:join:DelimiterSize"),
    when: "Delimiter array shape does not match join shape constraints.",
    message:
        "join: size of delimiter array must match the size of str, with the join dimension reduced by one",
};

/// Raised when a dimension argument is numeric but not a positive integer, or
/// when two trailing arguments are given and neither is a dimension.
const JOIN_ERROR_DIMENSION_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.DIMENSION_TYPE",
    identifier: Some("RunMat:join:DimensionType"),
    when: "Dimension argument is not a positive integer scalar.",
    message: "join: dimension must be a positive integer scalar",
};

/// Raised by `parse_arguments` when more than two arguments follow the text
/// operand.
const JOIN_ERROR_ARG_COUNT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.ARG_COUNT",
    identifier: Some("RunMat:join:ArgCount"),
    when: "More than three total arguments are supplied.",
    message: "join: too many input arguments",
};

/// Raised by `build_value` when constructing the output container fails; the
/// emitted message carries the underlying cause instead of `message`.
const JOIN_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.INTERNAL",
    identifier: Some("RunMat:join:InternalError"),
    when: "Internal output container construction failed.",
    message: "join: internal error",
};

/// Every error descriptor above, in declaration order, as published through
/// `JOIN_DESCRIPTOR`.
const JOIN_ERRORS: [BuiltinErrorDescriptor; 6] = [
    JOIN_ERROR_INPUT_TYPE,
    JOIN_ERROR_DELIMITER_TYPE,
    JOIN_ERROR_DELIMITER_SIZE,
    JOIN_ERROR_DIMENSION_TYPE,
    JOIN_ERROR_ARG_COUNT,
    JOIN_ERROR_INTERNAL,
];
227
/// Public descriptor for `join`, bundling its signatures and error table.
/// It is referenced by path from the `descriptor(...)` argument of the
/// `#[runtime_builtin]` attribute on `join_builtin`.
pub const JOIN_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &JOIN_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &JOIN_ERRORS,
};
234
/// Passes `err` through `map_control_flow_with_builtin`, tagging it with this
/// builtin's name. Used on errors coming out of the gather step.
fn map_flow(err: RuntimeError) -> RuntimeError {
    map_control_flow_with_builtin(err, BUILTIN_NAME)
}
238
239fn join_error_with_message(
240    message: impl Into<String>,
241    error: &'static BuiltinErrorDescriptor,
242) -> RuntimeError {
243    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
244    if let Some(identifier) = error.identifier {
245        builder = builder.with_identifier(identifier);
246    }
247    builder.build()
248}
249
/// Builds a runtime error from `error`, using the descriptor's own canned
/// message as the error text.
fn join_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
    join_error_with_message(error.message, error)
}
253
254#[runtime_builtin(
255    name = "join",
256    category = "strings/transform",
257    summary = "Join text elements with delimiters along a dimension.",
258    keywords = "join,string join,concatenate strings,delimiters,cell array join",
259    accel = "none",
260    type_resolver(text_concat_type),
261    descriptor(crate::builtins::strings::transform::join::JOIN_DESCRIPTOR),
262    builtin_path = "crate::builtins::strings::transform::join"
263)]
264async fn join_builtin(text: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
265    let text = gather_if_needed_async(&text).await.map_err(map_flow)?;
266    let mut args = Vec::with_capacity(rest.len());
267    for arg in rest {
268        args.push(gather_if_needed_async(&arg).await.map_err(map_flow)?);
269    }
270
271    let mut input = JoinInput::from_value(text)?;
272    let (delimiter_arg, dimension_arg) = parse_arguments(&args)?;
273
274    let mut shape = input.shape.clone();
275    if shape.is_empty() {
276        shape = vec![1, 1];
277    }
278
279    let default_dim = default_dimension(&shape);
280    let dimension = match dimension_arg {
281        Some(dim) => dim,
282        None => default_dim,
283    };
284
285    if dimension == 0 {
286        return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
287    }
288
289    let ndims = input.ndims();
290    if dimension > ndims {
291        return input.into_value();
292    }
293
294    let axis_idx = dimension - 1;
295    input.ensure_shape_len(dimension);
296    let full_shape = input.shape.clone();
297
298    let delimiter = Delimiter::from_value(delimiter_arg, &full_shape, axis_idx)?;
299
300    let (output_data, output_shape) = perform_join(&input.data, &full_shape, axis_idx, &delimiter);
301
302    input.build_output(output_data, output_shape)
303}
304
305fn parse_arguments(args: &[Value]) -> BuiltinResult<(Option<Value>, Option<usize>)> {
306    match args.len() {
307        0 => Ok((None, None)),
308        1 => {
309            if let Some(dim) = value_to_dimension(&args[0])? {
310                Ok((None, Some(dim)))
311            } else {
312                Ok((Some(args[0].clone()), None))
313            }
314        }
315        2 => {
316            if let Some(dim) = value_to_dimension(&args[1])? {
317                Ok((Some(args[0].clone()), Some(dim)))
318            } else if let Some(dim) = value_to_dimension(&args[0])? {
319                Ok((Some(args[1].clone()), Some(dim)))
320            } else {
321                Err(join_error(&JOIN_ERROR_DIMENSION_TYPE))
322            }
323        }
324        _ => Err(join_error(&JOIN_ERROR_ARG_COUNT)),
325    }
326}
327
/// Returns the 1-based index of the last dimension whose extent is not 1,
/// or 2 when every extent is 1 (including an empty shape).
fn default_dimension(shape: &[usize]) -> usize {
    shape
        .iter()
        .rposition(|&extent| extent != 1)
        .map_or(2, |last_non_singleton| last_non_singleton + 1)
}
336
337fn value_to_dimension(value: &Value) -> BuiltinResult<Option<usize>> {
338    match value {
339        Value::Int(i) => {
340            let v = i.to_i64();
341            if v <= 0 {
342                return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
343            }
344            Ok(Some(v as usize))
345        }
346        Value::Num(n) => {
347            if !n.is_finite() || *n <= 0.0 {
348                return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
349            }
350            let rounded = n.round();
351            if (rounded - n).abs() > f64::EPSILON {
352                return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
353            }
354            Ok(Some(rounded as usize))
355        }
356        Value::Tensor(t) if t.data.len() == 1 => {
357            let val = t.data[0];
358            if !val.is_finite() || val <= 0.0 {
359                return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
360            }
361            let rounded = val.round();
362            if (rounded - val).abs() > f64::EPSILON {
363                return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
364            }
365            Ok(Some(rounded as usize))
366        }
367        _ => Ok(None),
368    }
369}
370
/// The text operand of `join`, flattened into owned strings plus the
/// information needed to rebuild an output of the same container kind.
struct JoinInput {
    /// Elements as strings. `perform_join` indexes this with strides from
    /// `compute_strides`, i.e. with the first dimension varying fastest.
    data: Vec<String>,
    /// Extent of each dimension; `ensure_shape_len` may pad it with ones.
    shape: Vec<usize>,
    /// Container kind used by `build_value` when producing the result.
    kind: OutputKind,
}
376
/// Container kind that `build_value` produces for the joined text.
#[derive(Clone)]
enum OutputKind {
    /// Emit a single `Value::String`.
    StringScalar,
    /// Emit a `Value::StringArray` (also used for character-array inputs).
    StringArray,
    /// Emit a cell array whose elements are single-row character arrays.
    CellArray,
}
383
384impl JoinInput {
385    fn from_value(value: Value) -> BuiltinResult<Self> {
386        match value {
387            Value::String(text) => Ok(Self {
388                data: vec![text],
389                shape: vec![1, 1],
390                kind: OutputKind::StringScalar,
391            }),
392            Value::StringArray(array) => Ok(Self {
393                data: array.data,
394                shape: array.shape,
395                kind: OutputKind::StringArray,
396            }),
397            Value::CharArray(array) => {
398                let strings = char_array_rows_to_strings(&array);
399                Ok(Self {
400                    data: strings,
401                    shape: vec![array.rows, 1],
402                    kind: OutputKind::StringArray,
403                })
404            }
405            Value::Cell(cell) => {
406                let (data, shape) = cell_array_to_strings(cell)?;
407                Ok(Self {
408                    data,
409                    shape,
410                    kind: OutputKind::CellArray,
411                })
412            }
413            _ => Err(join_error(&JOIN_ERROR_INPUT_TYPE)),
414        }
415    }
416
417    fn ndims(&self) -> usize {
418        if self.shape.is_empty() {
419            2
420        } else {
421            self.shape.len().max(2)
422        }
423    }
424
425    fn ensure_shape_len(&mut self, dimension: usize) {
426        if self.shape.len() < dimension {
427            self.shape.resize(dimension, 1);
428        }
429    }
430
431    fn into_value(self) -> BuiltinResult<Value> {
432        build_value(self.kind, self.data, self.shape)
433    }
434
435    fn build_output(&self, data: Vec<String>, shape: Vec<usize>) -> BuiltinResult<Value> {
436        build_value(self.kind.clone(), data, shape)
437    }
438}
439
440fn build_value(kind: OutputKind, data: Vec<String>, shape: Vec<usize>) -> BuiltinResult<Value> {
441    match kind {
442        OutputKind::StringScalar => Ok(Value::String(data.into_iter().next().unwrap_or_default())),
443        OutputKind::StringArray => {
444            let array = StringArray::new(data, shape).map_err(|e| {
445                join_error_with_message(format!("{BUILTIN_NAME}: {e}"), &JOIN_ERROR_INTERNAL)
446            })?;
447            Ok(Value::StringArray(array))
448        }
449        OutputKind::CellArray => {
450            let rows = shape.first().copied().unwrap_or(0);
451            let cols = shape.get(1).copied().unwrap_or(1);
452            if rows == 0 || cols == 0 || data.is_empty() {
453                return make_cell(Vec::new(), rows, cols).map_err(|e| {
454                    join_error_with_message(format!("{BUILTIN_NAME}: {e}"), &JOIN_ERROR_INTERNAL)
455                });
456            }
457            let mut values = Vec::with_capacity(rows * cols);
458            for row in 0..rows {
459                for col in 0..cols {
460                    let idx = row + col * rows;
461                    let text = data[idx].clone();
462                    let chars: Vec<char> = text.chars().collect();
463                    let cols_count = chars.len();
464                    let char_array = CharArray::new(chars, 1, cols_count).map_err(|e| {
465                        join_error_with_message(
466                            format!("{BUILTIN_NAME}: {e}"),
467                            &JOIN_ERROR_INTERNAL,
468                        )
469                    })?;
470                    values.push(Value::CharArray(char_array));
471                }
472            }
473            make_cell(values, rows, cols).map_err(|e| {
474                join_error_with_message(format!("{BUILTIN_NAME}: {e}"), &JOIN_ERROR_INTERNAL)
475            })
476        }
477    }
478}
479
480fn char_array_rows_to_strings(array: &CharArray) -> Vec<String> {
481    let mut strings = Vec::with_capacity(array.rows);
482    for row in 0..array.rows {
483        strings.push(char_row_to_string_slice(&array.data, array.cols, row));
484    }
485    strings
486}
487
488fn cell_array_to_strings(cell: CellArray) -> BuiltinResult<(Vec<String>, Vec<usize>)> {
489    let CellArray {
490        data, rows, cols, ..
491    } = cell;
492    let mut strings = Vec::with_capacity(rows * cols);
493    for col in 0..cols {
494        for row in 0..rows {
495            let idx = row * cols + col;
496            strings.push(
497                cell_element_to_string(&data[idx])
498                    .ok_or_else(|| join_error(&JOIN_ERROR_INPUT_TYPE))?,
499            );
500        }
501    }
502    Ok((strings, vec![rows, cols]))
503}
504
505fn cell_element_to_string(value: &Value) -> Option<String> {
506    match value {
507        Value::String(text) => Some(text.clone()),
508        Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
509        Value::CharArray(array) if array.rows <= 1 => {
510            if array.rows == 0 {
511                Some(String::new())
512            } else {
513                Some(char_row_to_string_slice(&array.data, array.cols, 0))
514            }
515        }
516        _ => None,
517    }
518}
519
/// Delimiter text inserted between consecutive elements along the join axis.
#[derive(Clone)]
enum Delimiter {
    /// One delimiter used for every gap.
    Scalar(String),
    /// A delimiter looked up per output position and per gap.
    Array(DelimiterArray),
}
525
/// Delimiter array together with the indexing data needed to look up the
/// delimiter for a given coordinate and gap.
#[derive(Clone)]
struct DelimiterArray {
    /// Delimiter strings in flattened order.
    data: Vec<String>,
    /// Shape as returned by `normalize_delimiter_shape`, padded to the rank of
    /// the text operand.
    shape: Vec<usize>,
    /// Strides computed from `shape` by `compute_strides`.
    strides: Vec<usize>,
}
532
533impl Delimiter {
534    fn from_value(
535        value: Option<Value>,
536        full_shape: &[usize],
537        axis_idx: usize,
538    ) -> BuiltinResult<Self> {
539        match value {
540            None => Ok(Self::Scalar(" ".to_string())),
541            Some(v) => {
542                if let Some(text) = value_to_scalar_string(&v) {
543                    return Ok(Self::Scalar(text));
544                }
545                let (data, shape) = value_to_string_array(v)?;
546                let normalized = normalize_delimiter_shape(shape, full_shape, axis_idx)?;
547                let strides = compute_strides(&normalized);
548                Ok(Self::Array(DelimiterArray {
549                    data,
550                    shape: normalized,
551                    strides,
552                }))
553            }
554        }
555    }
556
557    fn value<'a>(&'a self, coords: &[usize], axis_idx: usize, axis_gap: usize) -> &'a str {
558        match self {
559            Delimiter::Scalar(text) => text.as_str(),
560            Delimiter::Array(array) => array.value(coords, axis_idx, axis_gap),
561        }
562    }
563}
564
565impl DelimiterArray {
566    fn value<'a>(&'a self, coords: &[usize], axis_idx: usize, axis_gap: usize) -> &'a str {
567        let mut offset = 0usize;
568        for (dim, stride) in self.strides.iter().enumerate() {
569            let size = self.shape[dim];
570            let coord = if dim == axis_idx {
571                axis_gap.min(size.saturating_sub(1))
572            } else if size == 1 {
573                0
574            } else {
575                coords[dim].min(size.saturating_sub(1))
576            };
577            offset += coord * stride;
578        }
579        &self.data[offset]
580    }
581}
582
583fn value_to_scalar_string(value: &Value) -> Option<String> {
584    match value {
585        Value::String(text) => Some(text.clone()),
586        Value::CharArray(array) if array.rows <= 1 => {
587            if array.rows == 0 {
588                Some(String::new())
589            } else {
590                Some(char_row_to_string_slice(&array.data, array.cols, 0))
591            }
592        }
593        Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
594        Value::Cell(cell) if cell.data.len() == 1 => cell_element_to_string(&cell.data[0]),
595        _ => None,
596    }
597}
598
599fn value_to_string_array(value: Value) -> BuiltinResult<(Vec<String>, Vec<usize>)> {
600    match value {
601        Value::StringArray(array) => Ok((array.data, array.shape)),
602        Value::Cell(cell) => {
603            let (data, shape) = cell_array_to_strings(cell)?;
604            Ok((data, shape))
605        }
606        Value::CharArray(array) => {
607            let rows = array.rows;
608            let strings = char_array_rows_to_strings(&array);
609            Ok((strings, vec![rows, 1]))
610        }
611        _ => Err(join_error(&JOIN_ERROR_DELIMITER_TYPE)),
612    }
613}
614
615fn normalize_delimiter_shape(
616    mut shape: Vec<usize>,
617    full_shape: &[usize],
618    axis_idx: usize,
619) -> BuiltinResult<Vec<usize>> {
620    if shape.len() > full_shape.len() {
621        return Err(join_error(&JOIN_ERROR_DELIMITER_SIZE));
622    }
623    if shape.len() < full_shape.len() {
624        shape.resize(full_shape.len(), 1);
625    }
626
627    let axis_len = full_shape[axis_idx].saturating_sub(1);
628    if axis_len == 0 {
629        shape[axis_idx] = 1;
630    } else if shape[axis_idx] != axis_len {
631        return Err(join_error(&JOIN_ERROR_DELIMITER_SIZE));
632    }
633
634    for (dim, size) in shape.iter().enumerate() {
635        if dim == axis_idx {
636            continue;
637        }
638        let reference = full_shape[dim];
639        if *size != reference && *size != 1 {
640            return Err(join_error(&JOIN_ERROR_DELIMITER_SIZE));
641        }
642    }
643
644    Ok(shape)
645}
646
647fn perform_join(
648    data: &[String],
649    full_shape: &[usize],
650    axis_idx: usize,
651    delimiter: &Delimiter,
652) -> (Vec<String>, Vec<usize>) {
653    if full_shape.is_empty() {
654        return (vec![String::new()], vec![1, 1]);
655    }
656
657    let axis_len = full_shape[axis_idx];
658    let mut output_shape = full_shape.to_vec();
659
660    let rest_size = full_shape
661        .iter()
662        .enumerate()
663        .filter(|(idx, _)| *idx != axis_idx)
664        .fold(1usize, |acc, (_, size)| acc.saturating_mul(*size));
665
666    if rest_size == 0 {
667        output_shape[axis_idx] = 0;
668        return (Vec::new(), output_shape);
669    }
670
671    output_shape[axis_idx] = 1;
672
673    let total_output = rest_size;
674    let mut output = Vec::with_capacity(total_output);
675
676    let strides = compute_strides(full_shape);
677    let axis_stride = strides[axis_idx];
678    let dims = full_shape.len();
679    let mut coords = vec![0usize; dims];
680
681    for _ in 0..rest_size {
682        let mut base_offset = 0usize;
683        for dim in 0..dims {
684            base_offset += coords[dim] * strides[dim];
685        }
686
687        if axis_len == 0 {
688            output.push(String::new());
689        } else {
690            let mut result = String::new();
691            let mut missing = false;
692            for axis_pos in 0..axis_len {
693                let element_offset = base_offset + axis_pos * axis_stride;
694                let value = &data[element_offset];
695                if is_missing_string(value) {
696                    missing = true;
697                    break;
698                }
699                if axis_pos > 0 {
700                    let gap = axis_pos - 1;
701                    let delim = delimiter.value(&coords, axis_idx, gap);
702                    result.push_str(delim);
703                }
704                result.push_str(value);
705            }
706            if missing {
707                output.push("<missing>".to_string());
708            } else {
709                output.push(result);
710            }
711        }
712
713        increment_coords(&mut coords, full_shape, axis_idx);
714    }
715
716    (output, output_shape)
717}
718
/// Computes element strides for `shape` with the first dimension varying
/// fastest: the first stride is 1 and each later stride is the previous
/// stride times the previous extent, saturating on overflow.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut running = 1usize;
    shape
        .iter()
        .map(|&extent| {
            let stride = running;
            running = running.saturating_mul(extent);
            stride
        })
        .collect()
}
726
/// Advances `coords` to the next position, odometer style, skipping the join
/// axis: the first non-axis coordinate is bumped, and a coordinate that
/// reaches its extent resets to zero and carries into the next one. After the
/// last position every non-axis coordinate wraps back to zero.
fn increment_coords(coords: &mut [usize], shape: &[usize], axis_idx: usize) {
    for (dim, &extent) in shape.iter().enumerate() {
        if dim == axis_idx {
            continue;
        }
        let slot = &mut coords[dim];
        *slot += 1;
        if *slot < extent {
            return;
        }
        // Overflowed this dimension: reset and carry into the next one.
        *slot = 0;
    }
}
739
740#[cfg(test)]
741pub(crate) mod tests {
742    use super::*;
743    #[cfg(feature = "wgpu")]
744    use runmat_accelerate::backend::wgpu::provider as wgpu_backend;
745    use runmat_builtins::{IntValue, ResolveContext, Type};
746
747    fn join_builtin(text: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
748        futures::executor::block_on(super::join_builtin(text, rest))
749    }
750
751    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
752    #[test]
753    fn join_string_array_default_dimension() {
754        let array = StringArray::new(
755            vec![
756                "Carlos".into(),
757                "Ella".into(),
758                "Diana".into(),
759                "Sada".into(),
760                "Olsen".into(),
761                "Lee".into(),
762            ],
763            vec![3, 2],
764        )
765        .unwrap();
766        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
767        match result {
768            Value::StringArray(sa) => {
769                assert_eq!(sa.shape, vec![3, 1]);
770                assert_eq!(
771                    sa.data,
772                    vec![
773                        "Carlos Sada".to_string(),
774                        "Ella Olsen".to_string(),
775                        "Diana Lee".to_string()
776                    ]
777                );
778            }
779            other => panic!("expected string array, got {other:?}"),
780        }
781    }
782
783    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
784    #[test]
785    fn join_with_custom_scalar_delimiter() {
786        let array = StringArray::new(
787            vec![
788                "x".into(),
789                "a".into(),
790                "y".into(),
791                "b".into(),
792                "z".into(),
793                "c".into(),
794            ],
795            vec![2, 3],
796        )
797        .unwrap();
798        let result =
799            join_builtin(Value::StringArray(array), vec![Value::String("-".into())]).expect("join");
800        match result {
801            Value::StringArray(sa) => {
802                assert_eq!(sa.shape, vec![2, 1]);
803                assert_eq!(sa.data, vec![String::from("x-y-z"), String::from("a-b-c")]);
804            }
805            other => panic!("expected string array, got {other:?}"),
806        }
807    }
808
809    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
810    #[test]
811    fn join_with_delimiter_array_per_row() {
812        let array = StringArray::new(
813            vec![
814                "x".into(),
815                "a".into(),
816                "y".into(),
817                "b".into(),
818                "z".into(),
819                "c".into(),
820            ],
821            vec![2, 3],
822        )
823        .unwrap();
824        let delims = StringArray::new(
825            vec![" + ".into(), " - ".into(), " = ".into(), " = ".into()],
826            vec![2, 2],
827        )
828        .unwrap();
829        let result = join_builtin(Value::StringArray(array), vec![Value::StringArray(delims)])
830            .expect("join");
831        match result {
832            Value::StringArray(sa) => {
833                assert_eq!(sa.shape, vec![2, 1]);
834                assert_eq!(
835                    sa.data,
836                    vec![String::from("x + y = z"), String::from("a - b = c")]
837                );
838            }
839            other => panic!("expected string array, got {other:?}"),
840        }
841    }
842
843    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
844    #[test]
845    fn join_with_dimension_argument() {
846        let array = StringArray::new(
847            vec![
848                "Carlos".into(),
849                "Ella".into(),
850                "Diana".into(),
851                "Sada".into(),
852                "Olsen".into(),
853                "Lee".into(),
854            ],
855            vec![3, 2],
856        )
857        .unwrap();
858        let result = join_builtin(
859            Value::StringArray(array),
860            vec![Value::Int(IntValue::I32(1))],
861        )
862        .expect("join");
863        match result {
864            Value::StringArray(sa) => {
865                assert_eq!(sa.shape, vec![1, 2]);
866                assert_eq!(
867                    sa.data,
868                    vec![
869                        String::from("Carlos Ella Diana"),
870                        String::from("Sada Olsen Lee"),
871                    ]
872                );
873            }
874            other => panic!("expected string array, got {other:?}"),
875        }
876    }
877
878    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
879    #[test]
880    fn join_dimension_greater_than_ndims_returns_input() {
881        let array = StringArray::new(vec!["a".into(), "b".into()], vec![1, 2]).unwrap();
882        let result = join_builtin(
883            Value::StringArray(array.clone()),
884            vec![Value::Int(IntValue::I32(4))],
885        )
886        .expect("join");
887        match result {
888            Value::StringArray(sa) => {
889                assert_eq!(sa.shape, array.shape);
890                assert_eq!(sa.data, array.data);
891            }
892            other => panic!("expected original array, got {other:?}"),
893        }
894    }
895
896    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
897    #[test]
       // Joins a 2x2 cell array of char row vectors with the delimiter ", " and expects
       // a 2x1 cell whose elements are char arrays, one joined row each.
898    fn join_cell_array_of_char_vectors() {
899        let gpu = CharArray::new_row("GPU");
900        let accel = CharArray::new_row("Accelerate");
901        let vm_label = CharArray::new_row("VM");
902        let interpreter = CharArray::new_row("Interpreter");
           // The assertions below pair GPU with Accelerate and VM with Interpreter, i.e.
           // the four values are consumed as two rows of two.
903        let values = vec![
904            Value::CharArray(gpu),
905            Value::CharArray(accel),
906            Value::CharArray(vm_label),
907            Value::CharArray(interpreter),
908        ];
909        let cell = make_cell(values, 2, 2).expect("cell");
910        let result = join_builtin(cell, vec![Value::String(", ".into())]).expect("join cell");
911        match result {
912            Value::Cell(cell_out) => {
913                assert_eq!(cell_out.rows, 2);
914                assert_eq!(cell_out.cols, 1);
                   // NOTE(review): these dereference the raw pointers returned by `as_raw()`;
                   // presumably they stay valid while `cell_out` is alive - confirm against
                   // the cell element type's contract.
915                let first = unsafe { &*cell_out.data[0].as_raw() };
916                let second = unsafe { &*cell_out.data[1].as_raw() };
917                match (first, second) {
918                    (Value::CharArray(a), Value::CharArray(b)) => {
                           // Row 0 of each output char array is read back as a string slice.
919                        assert_eq!(
920                            char_row_to_string_slice(&a.data, a.cols, 0),
921                            "GPU, Accelerate"
922                        );
923                        assert_eq!(
924                            char_row_to_string_slice(&b.data, b.cols, 0),
925                            "VM, Interpreter"
926                        );
927                    }
928                    other => panic!("expected char arrays, got {other:?}"),
929                }
930            }
931            other => panic!("expected cell array, got {other:?}"),
932        }
933    }
934
935    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
936    #[test]
       // Passes only an integer (1) after the 3x1 string array, with no delimiter
       // argument, and expects a 1x1 result whose pieces are separated by single spaces.
937    fn join_with_numeric_second_argument_uses_default_delimiter() {
938        let array = StringArray::new(
939            vec!["RunMat".into(), "Accelerate".into(), "Planner".into()],
940            vec![3, 1],
941        )
942        .unwrap();
943        let result = join_builtin(
944            Value::StringArray(array),
945            vec![Value::Int(IntValue::I32(1))],
946        )
947        .expect("join");
948        match result {
949            Value::StringArray(sa) => {
950                assert_eq!(sa.shape, vec![1, 1]);
951                assert_eq!(sa.data, vec![String::from("RunMat Accelerate Planner")]);
952            }
953            other => panic!("expected string array, got {other:?}"),
954        }
955    }
956
957    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
958    #[test]
       // Feeds a 3x4 char array to `join` with no extra arguments and expects a 1x1
       // string array rather than a char array.
959    fn join_char_array_input_produces_string_array() {
           // Twelve characters; the expected output splits them into consecutive groups
           // of four ("RunM", "atGP", "UDev") joined by single spaces.
960        let data: Vec<char> = "RunMatGPUDev".chars().collect();
961        let char_array = CharArray::new(data, 3, 4).unwrap();
962        let result = join_builtin(Value::CharArray(char_array), Vec::new()).expect("join");
963        match result {
964            Value::StringArray(sa) => {
965                assert_eq!(sa.shape, vec![1, 1]);
966                assert_eq!(sa.data, vec![String::from("RunM atGP UDev")]);
967            }
968            other => panic!("expected string array, got {other:?}"),
969        }
970    }
971
972    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
973    #[test]
       // Joins a 3x2 string array along dimension 2 using a 3x1 cell array of
       // delimiters, and expects each output row to use its own delimiter.
974    fn join_with_cell_delimiter_array() {
           // Per the expected results, the data is listed column by column: "g", "c", "w"
           // form the first column and "gpu", "cuda", "wgpu" the second.
975        let array = StringArray::new(
976            vec![
977                "g".into(),
978                "c".into(),
979                "w".into(),
980                "gpu".into(),
981                "cuda".into(),
982                "wgpu".into(),
983            ],
984            vec![3, 2],
985        )
986        .unwrap();
           // One delimiter per row of the input.
987        let delimiters = make_cell(
988            vec![
989                Value::String(String::from(" -> ")),
990                Value::String(String::from(" => ")),
991                Value::String(String::from(" :: ")),
992            ],
993            3,
994            1,
995        )
996        .expect("cell");
997        let result = join_builtin(
998            Value::StringArray(array),
999            vec![delimiters, Value::Int(IntValue::I32(2))],
1000        )
1001        .expect("join");
1002        match result {
1003            Value::StringArray(sa) => {
1004                assert_eq!(sa.shape, vec![3, 1]);
1005                assert_eq!(
1006                    sa.data,
1007                    vec![
1008                        String::from("g -> gpu"),
1009                        String::from("c => cuda"),
1010                        String::from("w :: wgpu")
1011                    ]
1012                );
1013            }
1014            other => panic!("expected string array, got {other:?}"),
1015        }
1016    }
1017
1018    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1019    #[test]
        // Joins a 2x2x2 string array along dimension 3 with ":" and expects a 2x2x1
        // result in which each element combines the two pages at the same row and column.
1020    fn join_3d_string_array_along_third_dimension() {
            // Elements are labelled by position and pushed with the row index varying
            // fastest, then the column, then the page.
1021        let mut data = Vec::new();
1022        for page in 0..2 {
1023            for col in 0..2 {
1024                for row in 0..2 {
1025                    data.push(format!("r{row}c{col}p{page}"));
1026                }
1027            }
1028        }
1029        let array = StringArray::new(data, vec![2, 2, 2]).unwrap();
1030        let result = join_builtin(
1031            Value::StringArray(array),
1032            vec![Value::String(":".into()), Value::Int(IntValue::I32(3))],
1033        )
1034        .expect("join");
1035        match result {
1036            Value::StringArray(sa) => {
1037                assert_eq!(sa.shape, vec![2, 2, 1]);
                    // Expected values follow the same row-fastest ordering as the input.
1038                let expected = vec![
1039                    String::from("r0c0p0:r0c0p1"),
1040                    String::from("r1c0p0:r1c0p1"),
1041                    String::from("r0c1p0:r0c1p1"),
1042                    String::from("r1c1p0:r1c1p1"),
1043                ];
1044                assert_eq!(sa.data, expected);
1045            }
1046            other => panic!("expected string array, got {other:?}"),
1047        }
1048    }
1049
1050    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1051    #[test]
        // Passes 0 as the integer argument and expects `join` to fail with an error
        // whose message mentions "dimension".
1052    fn join_errors_on_zero_dimension() {
1053        let array = StringArray::new(vec!["a".into()], vec![1, 1]).unwrap();
1054        let err = join_builtin(
1055            Value::StringArray(array),
1056            vec![Value::Int(IntValue::I32(0))],
1057        )
1058        .unwrap_err();
            // Only a substring of the message is checked, not its full wording.
1059        let err_text = err.to_string();
1060        assert!(
1061            err_text.contains("dimension"),
1062            "expected dimension error, got {err_text}"
1063        );
1064    }
1065
1066    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1067    #[test]
        // Supplies a 1x3 delimiter array for a 1x3 input and expects `join` to return
        // an error. Only `is_err()` is checked; the error text is not inspected.
        // NOTE(review): presumably rejected because joining three elements takes two
        // delimiters, not three - confirm against the delimiter validation.
1068    fn join_errors_on_mismatched_delimiter_shape() {
1069        let array = StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![1, 3]).unwrap();
1070        let delims =
1071            StringArray::new(vec!["+".into(), "-".into(), "=".into()], vec![1, 3]).unwrap();
1072        let result = join_builtin(Value::StringArray(array), vec![Value::StringArray(delims)]);
1073        assert!(result.is_err());
1074    }
1075
1076    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1077    #[test]
        // Joins a 1x2 string array whose second element is the literal "<missing>" and
        // expects the whole result to be the single element "<missing>".
1078    fn join_propagates_missing_strings() {
1079        let array = StringArray::new(vec!["GPU".into(), "<missing>".into()], vec![1, 2]).unwrap();
1080        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
1081        match result {
1082            Value::StringArray(sa) => {
                    // Only the data is asserted here; the result shape is not checked.
1083                assert_eq!(sa.data, vec![String::from("<missing>")]);
1084            }
1085            other => panic!("expected string array, got {other:?}"),
1086        }
1087    }
1088
1089    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1090    #[test]
        // Supplies the delimiter as a 1x2 char array ("++") rather than a string value
        // and expects it to be placed between the two elements.
1091    fn join_accepts_char_delimiter_scalar() {
1092        let array = StringArray::new(vec!["A".into(), "B".into()], vec![1, 2]).unwrap();
1093        let delimiter_chars = CharArray::new("++".chars().collect::<Vec<char>>(), 1, 2).unwrap();
1094        let result = join_builtin(
1095            Value::StringArray(array),
1096            vec![Value::CharArray(delimiter_chars)],
1097        )
1098        .expect("join");
1099        match result {
1100            Value::StringArray(sa) => {
1101                assert_eq!(sa.data, vec![String::from("A++B")]);
1102            }
1103            other => panic!("expected string array, got {other:?}"),
1104        }
1105    }
1106
1107    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1108    #[test]
        // Joins a 2x0 string array (no elements) with no extra arguments and expects a
        // 2x1 result holding two empty strings.
1109    fn join_handles_empty_axis() {
1110        let array = StringArray::new(Vec::new(), vec![2, 0]).unwrap();
1111        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
1112        match result {
1113            Value::StringArray(sa) => {
1114                assert_eq!(sa.shape, vec![2, 1]);
1115                assert_eq!(sa.data, vec![String::from(""), String::from("")]);
1116            }
1117            other => panic!("expected string array, got {other:?}"),
1118        }
1119    }
1120
1121    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1122    #[test]
        // Joins a 2x2 string array along dimension 2 with a 1x1 string-array delimiter
        // and expects that single "-" to be used for both output rows.
1123    fn join_missing_dimension_broadcast_delimiters() {
            // Per the expected results, the data is listed column by column: "aa", "cc"
            // form the first column and "bb", "dd" the second.
1124        let array = StringArray::new(
1125            vec!["aa".into(), "cc".into(), "bb".into(), "dd".into()],
1126            vec![2, 2],
1127        )
1128        .unwrap();
1129        let delims = StringArray::new(vec!["-".into()], vec![1, 1]).unwrap();
1130        let result = join_builtin(
1131            Value::StringArray(array),
1132            vec![Value::StringArray(delims), Value::Int(IntValue::I32(2))],
1133        )
1134        .expect("join");
1135        match result {
1136            Value::StringArray(sa) => {
1137                assert_eq!(sa.shape, vec![2, 1]);
1138                assert_eq!(sa.data, vec![String::from("aa-bb"), String::from("cc-dd")]);
1139            }
1140            other => panic!("expected string array, got {other:?}"),
1141        }
1142    }
1143
1144    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1145    #[test]
1146    #[cfg(feature = "wgpu")]
        // Compiled only with the "wgpu" feature. Registers the wgpu provider, then
        // checks that `join` on a 2x1 string array yields "GPU Planner".
1147    fn join_executes_with_wgpu_provider_registered() {
            // The registration result is deliberately discarded, so the test proceeds
            // whether or not registration succeeds.
1148        let _ = wgpu_backend::register_wgpu_provider(wgpu_backend::WgpuProviderOptions::default());
1149        let array = StringArray::new(vec!["GPU".into(), "Planner".into()], vec![2, 1]).unwrap();
1150        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
1151        match result {
1152            Value::StringArray(sa) => {
1153                assert_eq!(sa.data, vec![String::from("GPU Planner")]);
1154            }
1155            other => panic!("expected string array, got {other:?}"),
1156        }
1157    }
1158
1159    #[test]
        // Checks the type resolver: for a single `Type::String` argument and a
        // `ResolveContext` built from an empty vector, `text_concat_type` returns
        // `Type::String`.
1160    fn join_type_concatenates_text() {
1161        assert_eq!(
1162            text_concat_type(&[Type::String], &ResolveContext::new(Vec::new())),
1163            Type::String
1164        );
1165    }
1166}