1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6};
7use runmat_builtins::{CellArray, CharArray, StringArray, Value};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
16use crate::builtins::strings::type_resolvers::text_concat_type;
17use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
18
// GPU dispatch metadata for the `join` builtin.
//
// The spec lists no supported precisions and no provider hooks, and its residency
// policy is `GatherImmediately`; per `notes`, the work itself runs on the host.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::join")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "join",
    op_kind: GpuOpKind::Custom("string-transform"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Executes on the host; GPU-resident inputs and delimiters are gathered before concatenation.",
};
34
// Fusion metadata for the `join` builtin.
//
// Neither an elementwise nor a reduction kernel is declared; per `notes`, joins are
// not eligible for fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::join")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "join",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Joins operate on CPU-managed text and are ineligible for fusion.",
};
45
/// Name attached to runtime errors raised by this builtin and used as the prefix of
/// internal error messages.
const BUILTIN_NAME: &str = "join";
47
/// Output descriptor shared by every `join` signature: a single required `out` value.
const JOIN_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "out",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Joined text preserving join output container semantics.",
}];
55
/// Inputs for the `join(str)` signature.
const JOIN_INPUTS_BASE: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "str",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Input text (string/char/cell).",
}];
63
/// Inputs for the `join(str, delimiter)` signature.
const JOIN_INPUTS_DELIMITER: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "str",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input text (string/char/cell).",
    },
    BuiltinParamDescriptor {
        name: "delimiter",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        // Documents the single-space delimiter the builtin falls back to when none is given.
        default: Some("\" \""),
        description: "Delimiter scalar or delimiter array matching join shape constraints.",
    },
];
80
/// Inputs for the `join(str, dim)` signature.
const JOIN_INPUTS_DIM: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "str",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input text (string/char/cell).",
    },
    BuiltinParamDescriptor {
        name: "dim",
        ty: BuiltinParamType::IntegerScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Positive dimension index to join along.",
    },
];
97
/// Inputs for the `join(str, delimiter, dim)` signature.
const JOIN_INPUTS_DELIMITER_DIM: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "str",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input text (string/char/cell).",
    },
    BuiltinParamDescriptor {
        name: "delimiter",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Delimiter scalar or delimiter array matching join shape constraints.",
    },
    BuiltinParamDescriptor {
        name: "dim",
        ty: BuiltinParamType::IntegerScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Positive dimension index to join along.",
    },
];
121
/// Inputs for the `join(str, dim, delimiter)` signature (dimension before delimiter).
const JOIN_INPUTS_DIM_DELIMITER: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "str",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Input text (string/char/cell).",
    },
    BuiltinParamDescriptor {
        name: "dim",
        ty: BuiltinParamType::IntegerScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Positive dimension index to join along.",
    },
    BuiltinParamDescriptor {
        name: "delimiter",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Delimiter scalar or delimiter array matching join shape constraints.",
    },
];
145
/// The five call forms advertised for `join`; every form shares `JOIN_OUTPUT`.
const JOIN_SIGNATURES: [BuiltinSignatureDescriptor; 5] = [
    BuiltinSignatureDescriptor {
        label: "out = join(str)",
        inputs: &JOIN_INPUTS_BASE,
        outputs: &JOIN_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "out = join(str, delimiter)",
        inputs: &JOIN_INPUTS_DELIMITER,
        outputs: &JOIN_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "out = join(str, dim)",
        inputs: &JOIN_INPUTS_DIM,
        outputs: &JOIN_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "out = join(str, delimiter, dim)",
        inputs: &JOIN_INPUTS_DELIMITER_DIM,
        outputs: &JOIN_OUTPUT,
    },
    BuiltinSignatureDescriptor {
        label: "out = join(str, dim, delimiter)",
        inputs: &JOIN_INPUTS_DIM_DELIMITER,
        outputs: &JOIN_OUTPUT,
    },
];
173
/// Raised when the text input (or an element of an input cell array) is not text.
const JOIN_ERROR_INPUT_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.INPUT_TYPE",
    identifier: Some("RunMat:join:InputType"),
    when: "Input text is not a string array/scalar, char array, or cell array of text scalars.",
    message:
        "join: input must be a string array, string scalar, character array, or cell array of character vectors",
};
181
/// Raised when the delimiter argument is not a supported text value.
const JOIN_ERROR_DELIMITER_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.DELIMITER_TYPE",
    identifier: Some("RunMat:join:DelimiterType"),
    when: "Delimiter is not a supported text scalar/array/cell value.",
    message:
        "join: delimiter must be a string, character vector, string array, or cell array of character vectors",
};
189
/// Raised when a delimiter array's shape is incompatible with the input's shape.
const JOIN_ERROR_DELIMITER_SIZE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.DELIMITER_SIZE",
    identifier: Some("RunMat:join:DelimiterSize"),
    when: "Delimiter array shape does not match join shape constraints.",
    message:
        "join: size of delimiter array must match the size of str, with the join dimension reduced by one",
};
197
/// Raised when the dimension argument is not a positive integer scalar.
const JOIN_ERROR_DIMENSION_TYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.DIMENSION_TYPE",
    identifier: Some("RunMat:join:DimensionType"),
    when: "Dimension argument is not a positive integer scalar.",
    message: "join: dimension must be a positive integer scalar",
};
204
/// Raised when more than two optional arguments follow the text input.
const JOIN_ERROR_ARG_COUNT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.ARG_COUNT",
    identifier: Some("RunMat:join:ArgCount"),
    when: "More than three total arguments are supplied.",
    message: "join: too many input arguments",
};
211
/// Raised when constructing the output container fails.
const JOIN_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.JOIN.INTERNAL",
    identifier: Some("RunMat:join:InternalError"),
    when: "Internal output container construction failed.",
    message: "join: internal error",
};
218
/// All error descriptors published for `join`, referenced by `JOIN_DESCRIPTOR`.
const JOIN_ERRORS: [BuiltinErrorDescriptor; 6] = [
    JOIN_ERROR_INPUT_TYPE,
    JOIN_ERROR_DELIMITER_TYPE,
    JOIN_ERROR_DELIMITER_SIZE,
    JOIN_ERROR_DIMENSION_TYPE,
    JOIN_ERROR_ARG_COUNT,
    JOIN_ERROR_INTERNAL,
];
227
/// Public descriptor for `join`: its signatures, fixed output mode, public completion
/// policy, and the errors it can raise. Referenced by the `runtime_builtin` attribute
/// on `join_builtin`.
pub const JOIN_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &JOIN_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &JOIN_ERRORS,
};
234
235fn map_flow(err: RuntimeError) -> RuntimeError {
236 map_control_flow_with_builtin(err, BUILTIN_NAME)
237}
238
239fn join_error_with_message(
240 message: impl Into<String>,
241 error: &'static BuiltinErrorDescriptor,
242) -> RuntimeError {
243 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
244 if let Some(identifier) = error.identifier {
245 builder = builder.with_identifier(identifier);
246 }
247 builder.build()
248}
249
250fn join_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
251 join_error_with_message(error.message, error)
252}
253
254#[runtime_builtin(
255 name = "join",
256 category = "strings/transform",
257 summary = "Join text elements with delimiters along a dimension.",
258 keywords = "join,string join,concatenate strings,delimiters,cell array join",
259 accel = "none",
260 type_resolver(text_concat_type),
261 descriptor(crate::builtins::strings::transform::join::JOIN_DESCRIPTOR),
262 builtin_path = "crate::builtins::strings::transform::join"
263)]
264async fn join_builtin(text: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
265 let text = gather_if_needed_async(&text).await.map_err(map_flow)?;
266 let mut args = Vec::with_capacity(rest.len());
267 for arg in rest {
268 args.push(gather_if_needed_async(&arg).await.map_err(map_flow)?);
269 }
270
271 let mut input = JoinInput::from_value(text)?;
272 let (delimiter_arg, dimension_arg) = parse_arguments(&args)?;
273
274 let mut shape = input.shape.clone();
275 if shape.is_empty() {
276 shape = vec![1, 1];
277 }
278
279 let default_dim = default_dimension(&shape);
280 let dimension = match dimension_arg {
281 Some(dim) => dim,
282 None => default_dim,
283 };
284
285 if dimension == 0 {
286 return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
287 }
288
289 let ndims = input.ndims();
290 if dimension > ndims {
291 return input.into_value();
292 }
293
294 let axis_idx = dimension - 1;
295 input.ensure_shape_len(dimension);
296 let full_shape = input.shape.clone();
297
298 let delimiter = Delimiter::from_value(delimiter_arg, &full_shape, axis_idx)?;
299
300 let (output_data, output_shape) = perform_join(&input.data, &full_shape, axis_idx, &delimiter);
301
302 input.build_output(output_data, output_shape)
303}
304
305fn parse_arguments(args: &[Value]) -> BuiltinResult<(Option<Value>, Option<usize>)> {
306 match args.len() {
307 0 => Ok((None, None)),
308 1 => {
309 if let Some(dim) = value_to_dimension(&args[0])? {
310 Ok((None, Some(dim)))
311 } else {
312 Ok((Some(args[0].clone()), None))
313 }
314 }
315 2 => {
316 if let Some(dim) = value_to_dimension(&args[1])? {
317 Ok((Some(args[0].clone()), Some(dim)))
318 } else if let Some(dim) = value_to_dimension(&args[0])? {
319 Ok((Some(args[1].clone()), Some(dim)))
320 } else {
321 Err(join_error(&JOIN_ERROR_DIMENSION_TYPE))
322 }
323 }
324 _ => Err(join_error(&JOIN_ERROR_ARG_COUNT)),
325 }
326}
327
/// Returns the 1-based index of the last dimension whose size is not 1.
///
/// Falls back to 2 when every dimension is 1 or the shape is empty.
fn default_dimension(shape: &[usize]) -> usize {
    shape
        .iter()
        .rposition(|&extent| extent != 1)
        .map_or(2, |last_non_singleton| last_non_singleton + 1)
}
336
337fn value_to_dimension(value: &Value) -> BuiltinResult<Option<usize>> {
338 match value {
339 Value::Int(i) => {
340 let v = i.to_i64();
341 if v <= 0 {
342 return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
343 }
344 Ok(Some(v as usize))
345 }
346 Value::Num(n) => {
347 if !n.is_finite() || *n <= 0.0 {
348 return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
349 }
350 let rounded = n.round();
351 if (rounded - n).abs() > f64::EPSILON {
352 return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
353 }
354 Ok(Some(rounded as usize))
355 }
356 Value::Tensor(t) if t.data.len() == 1 => {
357 let val = t.data[0];
358 if !val.is_finite() || val <= 0.0 {
359 return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
360 }
361 let rounded = val.round();
362 if (rounded - val).abs() > f64::EPSILON {
363 return Err(join_error(&JOIN_ERROR_DIMENSION_TYPE));
364 }
365 Ok(Some(rounded as usize))
366 }
367 _ => Ok(None),
368 }
369}
370
/// Text input flattened into a uniform representation for joining.
struct JoinInput {
    /// Elements as owned strings; `perform_join` indexes them with strides derived
    /// from `shape` (first dimension varies fastest).
    data: Vec<String>,
    /// Dimension sizes of `data`.
    shape: Vec<usize>,
    /// Container kind to rebuild when producing the result.
    kind: OutputKind,
}
376
/// Which `Value` variant `build_value` produces for the joined result.
#[derive(Clone)]
enum OutputKind {
    /// Input was a `Value::String`; the result is a single `Value::String`.
    StringScalar,
    /// Input was a string array or char array; the result is a `Value::StringArray`.
    StringArray,
    /// Input was a cell array; the result is a cell array of char row vectors.
    CellArray,
}
383
384impl JoinInput {
385 fn from_value(value: Value) -> BuiltinResult<Self> {
386 match value {
387 Value::String(text) => Ok(Self {
388 data: vec![text],
389 shape: vec![1, 1],
390 kind: OutputKind::StringScalar,
391 }),
392 Value::StringArray(array) => Ok(Self {
393 data: array.data,
394 shape: array.shape,
395 kind: OutputKind::StringArray,
396 }),
397 Value::CharArray(array) => {
398 let strings = char_array_rows_to_strings(&array);
399 Ok(Self {
400 data: strings,
401 shape: vec![array.rows, 1],
402 kind: OutputKind::StringArray,
403 })
404 }
405 Value::Cell(cell) => {
406 let (data, shape) = cell_array_to_strings(cell)?;
407 Ok(Self {
408 data,
409 shape,
410 kind: OutputKind::CellArray,
411 })
412 }
413 _ => Err(join_error(&JOIN_ERROR_INPUT_TYPE)),
414 }
415 }
416
417 fn ndims(&self) -> usize {
418 if self.shape.is_empty() {
419 2
420 } else {
421 self.shape.len().max(2)
422 }
423 }
424
425 fn ensure_shape_len(&mut self, dimension: usize) {
426 if self.shape.len() < dimension {
427 self.shape.resize(dimension, 1);
428 }
429 }
430
431 fn into_value(self) -> BuiltinResult<Value> {
432 build_value(self.kind, self.data, self.shape)
433 }
434
435 fn build_output(&self, data: Vec<String>, shape: Vec<usize>) -> BuiltinResult<Value> {
436 build_value(self.kind.clone(), data, shape)
437 }
438}
439
440fn build_value(kind: OutputKind, data: Vec<String>, shape: Vec<usize>) -> BuiltinResult<Value> {
441 match kind {
442 OutputKind::StringScalar => Ok(Value::String(data.into_iter().next().unwrap_or_default())),
443 OutputKind::StringArray => {
444 let array = StringArray::new(data, shape).map_err(|e| {
445 join_error_with_message(format!("{BUILTIN_NAME}: {e}"), &JOIN_ERROR_INTERNAL)
446 })?;
447 Ok(Value::StringArray(array))
448 }
449 OutputKind::CellArray => {
450 let rows = shape.first().copied().unwrap_or(0);
451 let cols = shape.get(1).copied().unwrap_or(1);
452 if rows == 0 || cols == 0 || data.is_empty() {
453 return make_cell(Vec::new(), rows, cols).map_err(|e| {
454 join_error_with_message(format!("{BUILTIN_NAME}: {e}"), &JOIN_ERROR_INTERNAL)
455 });
456 }
457 let mut values = Vec::with_capacity(rows * cols);
458 for row in 0..rows {
459 for col in 0..cols {
460 let idx = row + col * rows;
461 let text = data[idx].clone();
462 let chars: Vec<char> = text.chars().collect();
463 let cols_count = chars.len();
464 let char_array = CharArray::new(chars, 1, cols_count).map_err(|e| {
465 join_error_with_message(
466 format!("{BUILTIN_NAME}: {e}"),
467 &JOIN_ERROR_INTERNAL,
468 )
469 })?;
470 values.push(Value::CharArray(char_array));
471 }
472 }
473 make_cell(values, rows, cols).map_err(|e| {
474 join_error_with_message(format!("{BUILTIN_NAME}: {e}"), &JOIN_ERROR_INTERNAL)
475 })
476 }
477 }
478}
479
480fn char_array_rows_to_strings(array: &CharArray) -> Vec<String> {
481 let mut strings = Vec::with_capacity(array.rows);
482 for row in 0..array.rows {
483 strings.push(char_row_to_string_slice(&array.data, array.cols, row));
484 }
485 strings
486}
487
488fn cell_array_to_strings(cell: CellArray) -> BuiltinResult<(Vec<String>, Vec<usize>)> {
489 let CellArray {
490 data, rows, cols, ..
491 } = cell;
492 let mut strings = Vec::with_capacity(rows * cols);
493 for col in 0..cols {
494 for row in 0..rows {
495 let idx = row * cols + col;
496 strings.push(
497 cell_element_to_string(&data[idx])
498 .ok_or_else(|| join_error(&JOIN_ERROR_INPUT_TYPE))?,
499 );
500 }
501 }
502 Ok((strings, vec![rows, cols]))
503}
504
505fn cell_element_to_string(value: &Value) -> Option<String> {
506 match value {
507 Value::String(text) => Some(text.clone()),
508 Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
509 Value::CharArray(array) if array.rows <= 1 => {
510 if array.rows == 0 {
511 Some(String::new())
512 } else {
513 Some(char_row_to_string_slice(&array.data, array.cols, 0))
514 }
515 }
516 _ => None,
517 }
518}
519
/// Delimiter inserted between consecutive elements along the join axis.
#[derive(Clone)]
enum Delimiter {
    /// One delimiter used for every gap.
    Scalar(String),
    /// One delimiter per gap, looked up by coordinates.
    Array(DelimiterArray),
}
525
/// Delimiter array together with the shape and strides used to index it.
#[derive(Clone)]
struct DelimiterArray {
    /// Delimiter strings in flat storage.
    data: Vec<String>,
    /// Shape produced by `normalize_delimiter_shape` (same rank as the joined input).
    shape: Vec<usize>,
    /// Strides computed from `shape` by `compute_strides`.
    strides: Vec<usize>,
}
532
533impl Delimiter {
534 fn from_value(
535 value: Option<Value>,
536 full_shape: &[usize],
537 axis_idx: usize,
538 ) -> BuiltinResult<Self> {
539 match value {
540 None => Ok(Self::Scalar(" ".to_string())),
541 Some(v) => {
542 if let Some(text) = value_to_scalar_string(&v) {
543 return Ok(Self::Scalar(text));
544 }
545 let (data, shape) = value_to_string_array(v)?;
546 let normalized = normalize_delimiter_shape(shape, full_shape, axis_idx)?;
547 let strides = compute_strides(&normalized);
548 Ok(Self::Array(DelimiterArray {
549 data,
550 shape: normalized,
551 strides,
552 }))
553 }
554 }
555 }
556
557 fn value<'a>(&'a self, coords: &[usize], axis_idx: usize, axis_gap: usize) -> &'a str {
558 match self {
559 Delimiter::Scalar(text) => text.as_str(),
560 Delimiter::Array(array) => array.value(coords, axis_idx, axis_gap),
561 }
562 }
563}
564
565impl DelimiterArray {
566 fn value<'a>(&'a self, coords: &[usize], axis_idx: usize, axis_gap: usize) -> &'a str {
567 let mut offset = 0usize;
568 for (dim, stride) in self.strides.iter().enumerate() {
569 let size = self.shape[dim];
570 let coord = if dim == axis_idx {
571 axis_gap.min(size.saturating_sub(1))
572 } else if size == 1 {
573 0
574 } else {
575 coords[dim].min(size.saturating_sub(1))
576 };
577 offset += coord * stride;
578 }
579 &self.data[offset]
580 }
581}
582
583fn value_to_scalar_string(value: &Value) -> Option<String> {
584 match value {
585 Value::String(text) => Some(text.clone()),
586 Value::CharArray(array) if array.rows <= 1 => {
587 if array.rows == 0 {
588 Some(String::new())
589 } else {
590 Some(char_row_to_string_slice(&array.data, array.cols, 0))
591 }
592 }
593 Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
594 Value::Cell(cell) if cell.data.len() == 1 => cell_element_to_string(&cell.data[0]),
595 _ => None,
596 }
597}
598
599fn value_to_string_array(value: Value) -> BuiltinResult<(Vec<String>, Vec<usize>)> {
600 match value {
601 Value::StringArray(array) => Ok((array.data, array.shape)),
602 Value::Cell(cell) => {
603 let (data, shape) = cell_array_to_strings(cell)?;
604 Ok((data, shape))
605 }
606 Value::CharArray(array) => {
607 let rows = array.rows;
608 let strings = char_array_rows_to_strings(&array);
609 Ok((strings, vec![rows, 1]))
610 }
611 _ => Err(join_error(&JOIN_ERROR_DELIMITER_TYPE)),
612 }
613}
614
615fn normalize_delimiter_shape(
616 mut shape: Vec<usize>,
617 full_shape: &[usize],
618 axis_idx: usize,
619) -> BuiltinResult<Vec<usize>> {
620 if shape.len() > full_shape.len() {
621 return Err(join_error(&JOIN_ERROR_DELIMITER_SIZE));
622 }
623 if shape.len() < full_shape.len() {
624 shape.resize(full_shape.len(), 1);
625 }
626
627 let axis_len = full_shape[axis_idx].saturating_sub(1);
628 if axis_len == 0 {
629 shape[axis_idx] = 1;
630 } else if shape[axis_idx] != axis_len {
631 return Err(join_error(&JOIN_ERROR_DELIMITER_SIZE));
632 }
633
634 for (dim, size) in shape.iter().enumerate() {
635 if dim == axis_idx {
636 continue;
637 }
638 let reference = full_shape[dim];
639 if *size != reference && *size != 1 {
640 return Err(join_error(&JOIN_ERROR_DELIMITER_SIZE));
641 }
642 }
643
644 Ok(shape)
645}
646
647fn perform_join(
648 data: &[String],
649 full_shape: &[usize],
650 axis_idx: usize,
651 delimiter: &Delimiter,
652) -> (Vec<String>, Vec<usize>) {
653 if full_shape.is_empty() {
654 return (vec![String::new()], vec![1, 1]);
655 }
656
657 let axis_len = full_shape[axis_idx];
658 let mut output_shape = full_shape.to_vec();
659
660 let rest_size = full_shape
661 .iter()
662 .enumerate()
663 .filter(|(idx, _)| *idx != axis_idx)
664 .fold(1usize, |acc, (_, size)| acc.saturating_mul(*size));
665
666 if rest_size == 0 {
667 output_shape[axis_idx] = 0;
668 return (Vec::new(), output_shape);
669 }
670
671 output_shape[axis_idx] = 1;
672
673 let total_output = rest_size;
674 let mut output = Vec::with_capacity(total_output);
675
676 let strides = compute_strides(full_shape);
677 let axis_stride = strides[axis_idx];
678 let dims = full_shape.len();
679 let mut coords = vec![0usize; dims];
680
681 for _ in 0..rest_size {
682 let mut base_offset = 0usize;
683 for dim in 0..dims {
684 base_offset += coords[dim] * strides[dim];
685 }
686
687 if axis_len == 0 {
688 output.push(String::new());
689 } else {
690 let mut result = String::new();
691 let mut missing = false;
692 for axis_pos in 0..axis_len {
693 let element_offset = base_offset + axis_pos * axis_stride;
694 let value = &data[element_offset];
695 if is_missing_string(value) {
696 missing = true;
697 break;
698 }
699 if axis_pos > 0 {
700 let gap = axis_pos - 1;
701 let delim = delimiter.value(&coords, axis_idx, gap);
702 result.push_str(delim);
703 }
704 result.push_str(value);
705 }
706 if missing {
707 output.push("<missing>".to_string());
708 } else {
709 output.push(result);
710 }
711 }
712
713 increment_coords(&mut coords, full_shape, axis_idx);
714 }
715
716 (output, output_shape)
717}
718
/// Computes element strides for `shape` with the first dimension varying fastest.
///
/// The first stride is 1 and each later stride is the previous stride times the
/// previous extent, saturating at `usize::MAX` instead of overflowing.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut running = 1usize;
    shape
        .iter()
        .map(|&extent| {
            let stride = running;
            running = running.saturating_mul(extent);
            stride
        })
        .collect()
}
726
/// Advances `coords` to the next slice, skipping the join axis.
///
/// Works like an odometer whose lowest digit is dimension 0: the first non-axis
/// coordinate is incremented, and on reaching its extent it resets to 0 and the carry
/// moves to the next non-axis dimension. After the last slice every coordinate is 0.
fn increment_coords(coords: &mut [usize], shape: &[usize], axis_idx: usize) {
    for (dim, &extent) in shape.iter().enumerate() {
        if dim == axis_idx {
            continue;
        }
        let slot = &mut coords[dim];
        *slot += 1;
        if *slot < extent {
            return;
        }
        // Overflowed this dimension: reset it and carry into the next one.
        *slot = 0;
    }
}
739
// Unit tests for the `join` builtin. Each test drives the async builtin through a
// blocking wrapper and checks the exact shape and contents of the result.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider as wgpu_backend;
    use runmat_builtins::{IntValue, ResolveContext, Type};

    // Synchronous wrapper that shadows the async builtin for the tests below.
    fn join_builtin(text: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(super::join_builtin(text, rest))
    }

    // A 3x2 string array with no extra arguments joins along dimension 2 with a space.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_string_array_default_dimension() {
        let array = StringArray::new(
            vec![
                "Carlos".into(),
                "Ella".into(),
                "Diana".into(),
                "Sada".into(),
                "Olsen".into(),
                "Lee".into(),
            ],
            vec![3, 2],
        )
        .unwrap();
        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                assert_eq!(
                    sa.data,
                    vec![
                        "Carlos Sada".to_string(),
                        "Ella Olsen".to_string(),
                        "Diana Lee".to_string()
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A scalar string delimiter is used for every gap.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_custom_scalar_delimiter() {
        let array = StringArray::new(
            vec![
                "x".into(),
                "a".into(),
                "y".into(),
                "b".into(),
                "z".into(),
                "c".into(),
            ],
            vec![2, 3],
        )
        .unwrap();
        let result =
            join_builtin(Value::StringArray(array), vec![Value::String("-".into())]).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(sa.data, vec![String::from("x-y-z"), String::from("a-b-c")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A 2x2 delimiter array supplies a distinct delimiter for each row and gap.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_delimiter_array_per_row() {
        let array = StringArray::new(
            vec![
                "x".into(),
                "a".into(),
                "y".into(),
                "b".into(),
                "z".into(),
                "c".into(),
            ],
            vec![2, 3],
        )
        .unwrap();
        let delims = StringArray::new(
            vec![" + ".into(), " - ".into(), " = ".into(), " = ".into()],
            vec![2, 2],
        )
        .unwrap();
        let result = join_builtin(Value::StringArray(array), vec![Value::StringArray(delims)])
            .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(
                    sa.data,
                    vec![String::from("x + y = z"), String::from("a - b = c")]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // An explicit dimension of 1 joins down the rows instead of across the columns.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_dimension_argument() {
        let array = StringArray::new(
            vec![
                "Carlos".into(),
                "Ella".into(),
                "Diana".into(),
                "Sada".into(),
                "Olsen".into(),
                "Lee".into(),
            ],
            vec![3, 2],
        )
        .unwrap();
        let result = join_builtin(
            Value::StringArray(array),
            vec![Value::Int(IntValue::I32(1))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![1, 2]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("Carlos Ella Diana"),
                        String::from("Sada Olsen Lee"),
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A dimension beyond the input's rank leaves the input unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_dimension_greater_than_ndims_returns_input() {
        let array = StringArray::new(vec!["a".into(), "b".into()], vec![1, 2]).unwrap();
        let result = join_builtin(
            Value::StringArray(array.clone()),
            vec![Value::Int(IntValue::I32(4))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, array.shape);
                assert_eq!(sa.data, array.data);
            }
            other => panic!("expected original array, got {other:?}"),
        }
    }

    // A cell array input produces a cell array of char row vectors.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_cell_array_of_char_vectors() {
        let gpu = CharArray::new_row("GPU");
        let accel = CharArray::new_row("Accelerate");
        let vm_label = CharArray::new_row("VM");
        let interpreter = CharArray::new_row("Interpreter");
        let values = vec![
            Value::CharArray(gpu),
            Value::CharArray(accel),
            Value::CharArray(vm_label),
            Value::CharArray(interpreter),
        ];
        let cell = make_cell(values, 2, 2).expect("cell");
        let result = join_builtin(cell, vec![Value::String(", ".into())]).expect("join cell");
        match result {
            Value::Cell(cell_out) => {
                assert_eq!(cell_out.rows, 2);
                assert_eq!(cell_out.cols, 1);
                // NOTE(review): presumably sound because `cell_out` keeps its elements
                // alive for the duration of these borrows — confirm against `as_raw`.
                let first = unsafe { &*cell_out.data[0].as_raw() };
                let second = unsafe { &*cell_out.data[1].as_raw() };
                match (first, second) {
                    (Value::CharArray(a), Value::CharArray(b)) => {
                        assert_eq!(
                            char_row_to_string_slice(&a.data, a.cols, 0),
                            "GPU, Accelerate"
                        );
                        assert_eq!(
                            char_row_to_string_slice(&b.data, b.cols, 0),
                            "VM, Interpreter"
                        );
                    }
                    other => panic!("expected char arrays, got {other:?}"),
                }
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A lone numeric argument is read as the dimension; the delimiter stays a space.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_numeric_second_argument_uses_default_delimiter() {
        let array = StringArray::new(
            vec!["RunMat".into(), "Accelerate".into(), "Planner".into()],
            vec![3, 1],
        )
        .unwrap();
        let result = join_builtin(
            Value::StringArray(array),
            vec![Value::Int(IntValue::I32(1))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![1, 1]);
                assert_eq!(sa.data, vec![String::from("RunMat Accelerate Planner")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A multi-row char array is treated as one string per row and yields a string array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_char_array_input_produces_string_array() {
        let data: Vec<char> = "RunMatGPUDev".chars().collect();
        let char_array = CharArray::new(data, 3, 4).unwrap();
        let result = join_builtin(Value::CharArray(char_array), Vec::new()).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![1, 1]);
                assert_eq!(sa.data, vec![String::from("RunM atGP UDev")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A 3x1 cell array of delimiters supplies one delimiter per row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_cell_delimiter_array() {
        let array = StringArray::new(
            vec![
                "g".into(),
                "c".into(),
                "w".into(),
                "gpu".into(),
                "cuda".into(),
                "wgpu".into(),
            ],
            vec![3, 2],
        )
        .unwrap();
        let delimiters = make_cell(
            vec![
                Value::String(String::from(" -> ")),
                Value::String(String::from(" => ")),
                Value::String(String::from(" :: ")),
            ],
            3,
            1,
        )
        .expect("cell");
        let result = join_builtin(
            Value::StringArray(array),
            vec![delimiters, Value::Int(IntValue::I32(2))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("g -> gpu"),
                        String::from("c => cuda"),
                        String::from("w :: wgpu")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Joining a 2x2x2 array along dimension 3 pairs up the two pages element-wise.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_3d_string_array_along_third_dimension() {
        let mut data = Vec::new();
        for page in 0..2 {
            for col in 0..2 {
                for row in 0..2 {
                    data.push(format!("r{row}c{col}p{page}"));
                }
            }
        }
        let array = StringArray::new(data, vec![2, 2, 2]).unwrap();
        let result = join_builtin(
            Value::StringArray(array),
            vec![Value::String(":".into()), Value::Int(IntValue::I32(3))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 2, 1]);
                let expected = vec![
                    String::from("r0c0p0:r0c0p1"),
                    String::from("r1c0p0:r1c0p1"),
                    String::from("r0c1p0:r0c1p1"),
                    String::from("r1c1p0:r1c1p1"),
                ];
                assert_eq!(sa.data, expected);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A dimension of zero is rejected with an error mentioning "dimension".
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_errors_on_zero_dimension() {
        let array = StringArray::new(vec!["a".into()], vec![1, 1]).unwrap();
        let err = join_builtin(
            Value::StringArray(array),
            vec![Value::Int(IntValue::I32(0))],
        )
        .unwrap_err();
        let err_text = err.to_string();
        assert!(
            err_text.contains("dimension"),
            "expected dimension error, got {err_text}"
        );
    }

    // Three delimiters for three elements (two gaps) is a size mismatch.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_errors_on_mismatched_delimiter_shape() {
        let array = StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![1, 3]).unwrap();
        let delims =
            StringArray::new(vec!["+".into(), "-".into(), "=".into()], vec![1, 3]).unwrap();
        let result = join_builtin(Value::StringArray(array), vec![Value::StringArray(delims)]);
        assert!(result.is_err());
    }

    // A missing element makes the whole joined result missing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_propagates_missing_strings() {
        let array = StringArray::new(vec!["GPU".into(), "<missing>".into()], vec![1, 2]).unwrap();
        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.data, vec![String::from("<missing>")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A single-row char array is accepted as a scalar delimiter.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_accepts_char_delimiter_scalar() {
        let array = StringArray::new(vec!["A".into(), "B".into()], vec![1, 2]).unwrap();
        let delimiter_chars = CharArray::new("++".chars().collect::<Vec<char>>(), 1, 2).unwrap();
        let result = join_builtin(
            Value::StringArray(array),
            vec![Value::CharArray(delimiter_chars)],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.data, vec![String::from("A++B")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Joining along a zero-length axis yields one empty string per remaining slice.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_handles_empty_axis() {
        let array = StringArray::new(Vec::new(), vec![2, 0]).unwrap();
        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(sa.data, vec![String::from(""), String::from("")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A 1x1 string array delimiter behaves as a scalar delimiter for every row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_missing_dimension_broadcast_delimiters() {
        let array = StringArray::new(
            vec!["aa".into(), "cc".into(), "bb".into(), "dd".into()],
            vec![2, 2],
        )
        .unwrap();
        let delims = StringArray::new(vec!["-".into()], vec![1, 1]).unwrap();
        let result = join_builtin(
            Value::StringArray(array),
            vec![Value::StringArray(delims), Value::Int(IntValue::I32(2))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(sa.data, vec![String::from("aa-bb"), String::from("cc-dd")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // The builtin still produces the host result after a wgpu provider is registered.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn join_executes_with_wgpu_provider_registered() {
        let _ = wgpu_backend::register_wgpu_provider(wgpu_backend::WgpuProviderOptions::default());
        let array = StringArray::new(vec!["GPU".into(), "Planner".into()], vec![2, 1]).unwrap();
        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.data, vec![String::from("GPU Planner")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // The type resolver used by `join` maps a string input type to a string output type.
    #[test]
    fn join_type_concatenates_text() {
        assert_eq!(
            text_concat_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}