1use runmat_builtins::{CellArray, CharArray, StringArray, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::spec::{
8 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
9 ReductionNaN, ResidencyPolicy, ShapeRequirements,
10};
11use crate::builtins::strings::common::{char_row_to_string_slice, is_missing_string};
12use crate::builtins::strings::type_resolvers::text_concat_type;
13use crate::{build_runtime_error, gather_if_needed_async, make_cell, BuiltinResult, RuntimeError};
14
/// GPU metadata for `join`.
///
/// There is no GPU kernel: the builtin is a custom host-side string transform, so no
/// precisions or provider hooks are declared. GPU-resident operands are gathered immediately
/// (`ResidencyPolicy::GatherImmediately`) before any text is concatenated.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::transform::join")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "join",
    op_kind: GpuOpKind::Custom("string-transform"),
    // Empty: the operation never executes on a device.
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Executes on the host; GPU-resident inputs and delimiters are gathered before concatenation.",
};
30
/// Fusion metadata for `join`.
///
/// Neither an elementwise nor a reduction template is provided (`None`), so the fusion planner
/// has nothing it can fuse for this builtin; text joins stay on the CPU path.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::transform::join")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "join",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Joins operate on CPU-managed text and are ineligible for fusion.",
};
41
/// Builtin name attached to every error raised by this module.
const BUILTIN_NAME: &str = "join";
/// Raised when the first argument is not text (or a cell element is not a text scalar).
const INPUT_TYPE_ERROR: &str =
    "join: input must be a string array, string scalar, character array, or cell array of character vectors";
/// Raised when the delimiter argument is neither a text scalar nor an array of text.
const DELIMITER_TYPE_ERROR: &str =
    "join: delimiter must be a string, character vector, string array, or cell array of character vectors";
/// Raised when a delimiter array's shape is incompatible with the input's shape.
const DELIMITER_SIZE_ERROR: &str =
    "join: size of delimiter array must match the size of str, with the join dimension reduced by one";
/// Raised for a non-positive, non-integer, or otherwise invalid dimension argument.
const DIMENSION_TYPE_ERROR: &str = "join: dimension must be a positive integer scalar";
50
51fn runtime_error_for(message: impl Into<String>) -> RuntimeError {
52 build_runtime_error(message)
53 .with_builtin(BUILTIN_NAME)
54 .build()
55}
56
57fn map_flow(err: RuntimeError) -> RuntimeError {
58 map_control_flow_with_builtin(err, BUILTIN_NAME)
59}
60
61#[runtime_builtin(
62 name = "join",
63 category = "strings/transform",
64 summary = "Combine text across a specified dimension inserting delimiters between elements.",
65 keywords = "join,string join,concatenate strings,delimiters,cell array join",
66 accel = "none",
67 type_resolver(text_concat_type),
68 builtin_path = "crate::builtins::strings::transform::join"
69)]
70async fn join_builtin(text: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
71 let text = gather_if_needed_async(&text).await.map_err(map_flow)?;
72 let mut args = Vec::with_capacity(rest.len());
73 for arg in rest {
74 args.push(gather_if_needed_async(&arg).await.map_err(map_flow)?);
75 }
76
77 let mut input = JoinInput::from_value(text)?;
78 let (delimiter_arg, dimension_arg) = parse_arguments(&args)?;
79
80 let mut shape = input.shape.clone();
81 if shape.is_empty() {
82 shape = vec![1, 1];
83 }
84
85 let default_dim = default_dimension(&shape);
86 let dimension = match dimension_arg {
87 Some(dim) => dim,
88 None => default_dim,
89 };
90
91 if dimension == 0 {
92 return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
93 }
94
95 let ndims = input.ndims();
96 if dimension > ndims {
97 return input.into_value();
98 }
99
100 let axis_idx = dimension - 1;
101 input.ensure_shape_len(dimension);
102 let full_shape = input.shape.clone();
103
104 let delimiter = Delimiter::from_value(delimiter_arg, &full_shape, axis_idx)?;
105
106 let (output_data, output_shape) = perform_join(&input.data, &full_shape, axis_idx, &delimiter);
107
108 input.build_output(output_data, output_shape)
109}
110
111fn parse_arguments(args: &[Value]) -> BuiltinResult<(Option<Value>, Option<usize>)> {
112 match args.len() {
113 0 => Ok((None, None)),
114 1 => {
115 if let Some(dim) = value_to_dimension(&args[0])? {
116 Ok((None, Some(dim)))
117 } else {
118 Ok((Some(args[0].clone()), None))
119 }
120 }
121 2 => {
122 if let Some(dim) = value_to_dimension(&args[1])? {
123 Ok((Some(args[0].clone()), Some(dim)))
124 } else if let Some(dim) = value_to_dimension(&args[0])? {
125 Ok((Some(args[1].clone()), Some(dim)))
126 } else {
127 Err(runtime_error_for(DIMENSION_TYPE_ERROR))
128 }
129 }
130 _ => Err(runtime_error_for("join: too many input arguments")),
131 }
132}
133
/// Returns the 1-based index of the last dimension whose size is not 1, or 2 if there is none
/// (including the empty shape).
fn default_dimension(shape: &[usize]) -> usize {
    shape
        .iter()
        .rposition(|&extent| extent != 1)
        .map_or(2, |position| position + 1)
}
142
143fn value_to_dimension(value: &Value) -> BuiltinResult<Option<usize>> {
144 match value {
145 Value::Int(i) => {
146 let v = i.to_i64();
147 if v <= 0 {
148 return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
149 }
150 Ok(Some(v as usize))
151 }
152 Value::Num(n) => {
153 if !n.is_finite() || *n <= 0.0 {
154 return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
155 }
156 let rounded = n.round();
157 if (rounded - n).abs() > f64::EPSILON {
158 return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
159 }
160 Ok(Some(rounded as usize))
161 }
162 Value::Tensor(t) if t.data.len() == 1 => {
163 let val = t.data[0];
164 if !val.is_finite() || val <= 0.0 {
165 return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
166 }
167 let rounded = val.round();
168 if (rounded - val).abs() > f64::EPSILON {
169 return Err(runtime_error_for(DIMENSION_TYPE_ERROR));
170 }
171 Ok(Some(rounded as usize))
172 }
173 _ => Ok(None),
174 }
175}
176
/// The `str` argument normalised to a flat list of strings plus its shape.
struct JoinInput {
    /// Element text in column-major order.
    data: Vec<String>,
    /// Array dimensions; padded with trailing 1s before joining.
    shape: Vec<usize>,
    /// Which `Value` variant the result is rebuilt as.
    kind: OutputKind,
}
182
/// Output container chosen from the input type (see `JoinInput::from_value`).
#[derive(Clone)]
enum OutputKind {
    /// Result is a single `Value::String` (input was a string scalar).
    StringScalar,
    /// Result is a `Value::StringArray` (string-array or char-array input).
    StringArray,
    /// Result is a cell array of char row vectors (cell input).
    CellArray,
}
189
190impl JoinInput {
191 fn from_value(value: Value) -> BuiltinResult<Self> {
192 match value {
193 Value::String(text) => Ok(Self {
194 data: vec![text],
195 shape: vec![1, 1],
196 kind: OutputKind::StringScalar,
197 }),
198 Value::StringArray(array) => Ok(Self {
199 data: array.data,
200 shape: array.shape,
201 kind: OutputKind::StringArray,
202 }),
203 Value::CharArray(array) => {
204 let strings = char_array_rows_to_strings(&array);
205 Ok(Self {
206 data: strings,
207 shape: vec![array.rows, 1],
208 kind: OutputKind::StringArray,
209 })
210 }
211 Value::Cell(cell) => {
212 let (data, shape) = cell_array_to_strings(cell)?;
213 Ok(Self {
214 data,
215 shape,
216 kind: OutputKind::CellArray,
217 })
218 }
219 _ => Err(runtime_error_for(INPUT_TYPE_ERROR)),
220 }
221 }
222
223 fn ndims(&self) -> usize {
224 if self.shape.is_empty() {
225 2
226 } else {
227 self.shape.len().max(2)
228 }
229 }
230
231 fn ensure_shape_len(&mut self, dimension: usize) {
232 if self.shape.len() < dimension {
233 self.shape.resize(dimension, 1);
234 }
235 }
236
237 fn into_value(self) -> BuiltinResult<Value> {
238 build_value(self.kind, self.data, self.shape)
239 }
240
241 fn build_output(&self, data: Vec<String>, shape: Vec<usize>) -> BuiltinResult<Value> {
242 build_value(self.kind.clone(), data, shape)
243 }
244}
245
246fn build_value(kind: OutputKind, data: Vec<String>, shape: Vec<usize>) -> BuiltinResult<Value> {
247 match kind {
248 OutputKind::StringScalar => Ok(Value::String(data.into_iter().next().unwrap_or_default())),
249 OutputKind::StringArray => {
250 let array = StringArray::new(data, shape)
251 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
252 Ok(Value::StringArray(array))
253 }
254 OutputKind::CellArray => {
255 let rows = shape.first().copied().unwrap_or(0);
256 let cols = shape.get(1).copied().unwrap_or(1);
257 if rows == 0 || cols == 0 || data.is_empty() {
258 return make_cell(Vec::new(), rows, cols)
259 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")));
260 }
261 let mut values = Vec::with_capacity(rows * cols);
262 for row in 0..rows {
263 for col in 0..cols {
264 let idx = row + col * rows;
265 let text = data[idx].clone();
266 let chars: Vec<char> = text.chars().collect();
267 let cols_count = chars.len();
268 let char_array = CharArray::new(chars, 1, cols_count)
269 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))?;
270 values.push(Value::CharArray(char_array));
271 }
272 }
273 make_cell(values, rows, cols)
274 .map_err(|e| runtime_error_for(format!("{BUILTIN_NAME}: {e}")))
275 }
276 }
277}
278
279fn char_array_rows_to_strings(array: &CharArray) -> Vec<String> {
280 let mut strings = Vec::with_capacity(array.rows);
281 for row in 0..array.rows {
282 strings.push(char_row_to_string_slice(&array.data, array.cols, row));
283 }
284 strings
285}
286
287fn cell_array_to_strings(cell: CellArray) -> BuiltinResult<(Vec<String>, Vec<usize>)> {
288 let CellArray {
289 data, rows, cols, ..
290 } = cell;
291 let mut strings = Vec::with_capacity(rows * cols);
292 for col in 0..cols {
293 for row in 0..rows {
294 let idx = row * cols + col;
295 strings.push(
296 cell_element_to_string(&data[idx])
297 .ok_or_else(|| runtime_error_for(INPUT_TYPE_ERROR))?,
298 );
299 }
300 }
301 Ok((strings, vec![rows, cols]))
302}
303
304fn cell_element_to_string(value: &Value) -> Option<String> {
305 match value {
306 Value::String(text) => Some(text.clone()),
307 Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
308 Value::CharArray(array) if array.rows <= 1 => {
309 if array.rows == 0 {
310 Some(String::new())
311 } else {
312 Some(char_row_to_string_slice(&array.data, array.cols, 0))
313 }
314 }
315 _ => None,
316 }
317}
318
/// Delimiter to insert between joined elements.
#[derive(Clone)]
enum Delimiter {
    /// The same text is used for every gap (default is a single space).
    Scalar(String),
    /// A per-gap delimiter array whose shape is validated against the input.
    Array(DelimiterArray),
}
324
/// A delimiter array normalised to the input's rank.
#[derive(Clone)]
struct DelimiterArray {
    /// Delimiter text in column-major order.
    data: Vec<String>,
    /// Shape after `normalize_delimiter_shape` (same rank as the input).
    shape: Vec<usize>,
    /// Strides computed from `shape` by `compute_strides`.
    strides: Vec<usize>,
}
331
332impl Delimiter {
333 fn from_value(
334 value: Option<Value>,
335 full_shape: &[usize],
336 axis_idx: usize,
337 ) -> BuiltinResult<Self> {
338 match value {
339 None => Ok(Self::Scalar(" ".to_string())),
340 Some(v) => {
341 if let Some(text) = value_to_scalar_string(&v) {
342 return Ok(Self::Scalar(text));
343 }
344 let (data, shape) = value_to_string_array(v)?;
345 let normalized = normalize_delimiter_shape(shape, full_shape, axis_idx)?;
346 let strides = compute_strides(&normalized);
347 Ok(Self::Array(DelimiterArray {
348 data,
349 shape: normalized,
350 strides,
351 }))
352 }
353 }
354 }
355
356 fn value<'a>(&'a self, coords: &[usize], axis_idx: usize, axis_gap: usize) -> &'a str {
357 match self {
358 Delimiter::Scalar(text) => text.as_str(),
359 Delimiter::Array(array) => array.value(coords, axis_idx, axis_gap),
360 }
361 }
362}
363
364impl DelimiterArray {
365 fn value<'a>(&'a self, coords: &[usize], axis_idx: usize, axis_gap: usize) -> &'a str {
366 let mut offset = 0usize;
367 for (dim, stride) in self.strides.iter().enumerate() {
368 let size = self.shape[dim];
369 let coord = if dim == axis_idx {
370 axis_gap.min(size.saturating_sub(1))
371 } else if size == 1 {
372 0
373 } else {
374 coords[dim].min(size.saturating_sub(1))
375 };
376 offset += coord * stride;
377 }
378 &self.data[offset]
379 }
380}
381
382fn value_to_scalar_string(value: &Value) -> Option<String> {
383 match value {
384 Value::String(text) => Some(text.clone()),
385 Value::CharArray(array) if array.rows <= 1 => {
386 if array.rows == 0 {
387 Some(String::new())
388 } else {
389 Some(char_row_to_string_slice(&array.data, array.cols, 0))
390 }
391 }
392 Value::StringArray(array) if array.data.len() == 1 => Some(array.data[0].clone()),
393 Value::Cell(cell) if cell.data.len() == 1 => cell_element_to_string(&cell.data[0]),
394 _ => None,
395 }
396}
397
398fn value_to_string_array(value: Value) -> BuiltinResult<(Vec<String>, Vec<usize>)> {
399 match value {
400 Value::StringArray(array) => Ok((array.data, array.shape)),
401 Value::Cell(cell) => {
402 let (data, shape) = cell_array_to_strings(cell)?;
403 Ok((data, shape))
404 }
405 Value::CharArray(array) => {
406 let rows = array.rows;
407 let strings = char_array_rows_to_strings(&array);
408 Ok((strings, vec![rows, 1]))
409 }
410 _ => Err(runtime_error_for(DELIMITER_TYPE_ERROR)),
411 }
412}
413
414fn normalize_delimiter_shape(
415 mut shape: Vec<usize>,
416 full_shape: &[usize],
417 axis_idx: usize,
418) -> BuiltinResult<Vec<usize>> {
419 if shape.len() > full_shape.len() {
420 return Err(runtime_error_for(DELIMITER_SIZE_ERROR));
421 }
422 if shape.len() < full_shape.len() {
423 shape.resize(full_shape.len(), 1);
424 }
425
426 let axis_len = full_shape[axis_idx].saturating_sub(1);
427 if axis_len == 0 {
428 shape[axis_idx] = 1;
429 } else if shape[axis_idx] != axis_len {
430 return Err(runtime_error_for(DELIMITER_SIZE_ERROR));
431 }
432
433 for (dim, size) in shape.iter().enumerate() {
434 if dim == axis_idx {
435 continue;
436 }
437 let reference = full_shape[dim];
438 if *size != reference && *size != 1 {
439 return Err(runtime_error_for(DELIMITER_SIZE_ERROR));
440 }
441 }
442
443 Ok(shape)
444}
445
446fn perform_join(
447 data: &[String],
448 full_shape: &[usize],
449 axis_idx: usize,
450 delimiter: &Delimiter,
451) -> (Vec<String>, Vec<usize>) {
452 if full_shape.is_empty() {
453 return (vec![String::new()], vec![1, 1]);
454 }
455
456 let axis_len = full_shape[axis_idx];
457 let mut output_shape = full_shape.to_vec();
458
459 let rest_size = full_shape
460 .iter()
461 .enumerate()
462 .filter(|(idx, _)| *idx != axis_idx)
463 .fold(1usize, |acc, (_, size)| acc.saturating_mul(*size));
464
465 if rest_size == 0 {
466 output_shape[axis_idx] = 0;
467 return (Vec::new(), output_shape);
468 }
469
470 output_shape[axis_idx] = 1;
471
472 let total_output = rest_size;
473 let mut output = Vec::with_capacity(total_output);
474
475 let strides = compute_strides(full_shape);
476 let axis_stride = strides[axis_idx];
477 let dims = full_shape.len();
478 let mut coords = vec![0usize; dims];
479
480 for _ in 0..rest_size {
481 let mut base_offset = 0usize;
482 for dim in 0..dims {
483 base_offset += coords[dim] * strides[dim];
484 }
485
486 if axis_len == 0 {
487 output.push(String::new());
488 } else {
489 let mut result = String::new();
490 let mut missing = false;
491 for axis_pos in 0..axis_len {
492 let element_offset = base_offset + axis_pos * axis_stride;
493 let value = &data[element_offset];
494 if is_missing_string(value) {
495 missing = true;
496 break;
497 }
498 if axis_pos > 0 {
499 let gap = axis_pos - 1;
500 let delim = delimiter.value(&coords, axis_idx, gap);
501 result.push_str(delim);
502 }
503 result.push_str(value);
504 }
505 if missing {
506 output.push("<missing>".to_string());
507 } else {
508 output.push(result);
509 }
510 }
511
512 increment_coords(&mut coords, full_shape, axis_idx);
513 }
514
515 (output, output_shape)
516}
517
/// Column-major strides for `shape`: the first stride is 1, and each later stride is the
/// product of all earlier extents (saturating at `usize::MAX`).
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut running = 1usize;
    shape
        .iter()
        .map(|&extent| {
            let stride = running;
            running = running.saturating_mul(extent);
            stride
        })
        .collect()
}
525
/// Advances `coords` to the next lane in column-major order, skipping `axis_idx`.
///
/// The first non-skipped dimension varies fastest; each dimension wraps to 0 and carries into
/// the next one when it reaches its extent.
fn increment_coords(coords: &mut [usize], shape: &[usize], axis_idx: usize) {
    for (dim, &extent) in shape.iter().enumerate() {
        if dim == axis_idx {
            continue;
        }
        let slot = &mut coords[dim];
        *slot += 1;
        if *slot < extent {
            return;
        }
        *slot = 0;
    }
}
538
// Unit tests for `join`. They call the async builtin through a blocking wrapper and check
// output shapes and column-major data ordering for string, char and cell inputs.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate::backend::wgpu::provider as wgpu_backend;
    use runmat_builtins::{IntValue, ResolveContext, Type};

    // Synchronous shim over the async builtin so the tests can call it directly.
    fn join_builtin(text: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(super::join_builtin(text, rest))
    }

    // 3x2 input: the default dimension is 2 (last non-singleton), so each row is joined with
    // the default space delimiter.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_string_array_default_dimension() {
        let array = StringArray::new(
            vec![
                "Carlos".into(),
                "Ella".into(),
                "Diana".into(),
                "Sada".into(),
                "Olsen".into(),
                "Lee".into(),
            ],
            vec![3, 2],
        )
        .unwrap();
        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                assert_eq!(
                    sa.data,
                    vec![
                        "Carlos Sada".to_string(),
                        "Ella Olsen".to_string(),
                        "Diana Lee".to_string()
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A string-scalar delimiter is used for every gap.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_custom_scalar_delimiter() {
        let array = StringArray::new(
            vec![
                "x".into(),
                "a".into(),
                "y".into(),
                "b".into(),
                "z".into(),
                "c".into(),
            ],
            vec![2, 3],
        )
        .unwrap();
        let result =
            join_builtin(Value::StringArray(array), vec![Value::String("-".into())]).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(sa.data, vec![String::from("x-y-z"), String::from("a-b-c")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // 2x3 input joined along dim 2 takes a 2x2 delimiter array (one column per gap). The
    // delimiter data is column-major, so row 0 uses " + " then " = ".
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_delimiter_array_per_row() {
        let array = StringArray::new(
            vec![
                "x".into(),
                "a".into(),
                "y".into(),
                "b".into(),
                "z".into(),
                "c".into(),
            ],
            vec![2, 3],
        )
        .unwrap();
        let delims = StringArray::new(
            vec![" + ".into(), " - ".into(), " = ".into(), " = ".into()],
            vec![2, 2],
        )
        .unwrap();
        let result = join_builtin(Value::StringArray(array), vec![Value::StringArray(delims)])
            .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(
                    sa.data,
                    vec![String::from("x + y = z"), String::from("a - b = c")]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // An explicit dimension of 1 joins down each column instead of across rows.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_dimension_argument() {
        let array = StringArray::new(
            vec![
                "Carlos".into(),
                "Ella".into(),
                "Diana".into(),
                "Sada".into(),
                "Olsen".into(),
                "Lee".into(),
            ],
            vec![3, 2],
        )
        .unwrap();
        let result = join_builtin(
            Value::StringArray(array),
            vec![Value::Int(IntValue::I32(1))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![1, 2]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("Carlos Ella Diana"),
                        String::from("Sada Olsen Lee"),
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A dimension beyond the input's rank returns the input unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_dimension_greater_than_ndims_returns_input() {
        let array = StringArray::new(vec!["a".into(), "b".into()], vec![1, 2]).unwrap();
        let result = join_builtin(
            Value::StringArray(array.clone()),
            vec![Value::Int(IntValue::I32(4))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, array.shape);
                assert_eq!(sa.data, array.data);
            }
            other => panic!("expected original array, got {other:?}"),
        }
    }

    // Cell input produces a cell of char rows; the 2x2 cell is joined across each row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_cell_array_of_char_vectors() {
        let gpu = CharArray::new_row("GPU");
        let accel = CharArray::new_row("Accelerate");
        let ignition = CharArray::new_row("Ignition");
        let interpreter = CharArray::new_row("Interpreter");
        let values = vec![
            Value::CharArray(gpu),
            Value::CharArray(accel),
            Value::CharArray(ignition),
            Value::CharArray(interpreter),
        ];
        let cell = make_cell(values, 2, 2).expect("cell");
        let result = join_builtin(cell, vec![Value::String(", ".into())]).expect("join cell");
        match result {
            Value::Cell(cell_out) => {
                assert_eq!(cell_out.rows, 2);
                assert_eq!(cell_out.cols, 1);
                let first = unsafe { &*cell_out.data[0].as_raw() };
                let second = unsafe { &*cell_out.data[1].as_raw() };
                match (first, second) {
                    (Value::CharArray(a), Value::CharArray(b)) => {
                        assert_eq!(
                            char_row_to_string_slice(&a.data, a.cols, 0),
                            "GPU, Accelerate"
                        );
                        assert_eq!(
                            char_row_to_string_slice(&b.data, b.cols, 0),
                            "Ignition, Interpreter"
                        );
                    }
                    other => panic!("expected char arrays, got {other:?}"),
                }
            }
            other => panic!("expected cell array, got {other:?}"),
        }
    }

    // A lone numeric argument is parsed as the dimension, so the default space delimiter
    // still applies.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_numeric_second_argument_uses_default_delimiter() {
        let array = StringArray::new(
            vec!["RunMat".into(), "Accelerate".into(), "Planner".into()],
            vec![3, 1],
        )
        .unwrap();
        let result = join_builtin(
            Value::StringArray(array),
            vec![Value::Int(IntValue::I32(1))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![1, 1]);
                assert_eq!(sa.data, vec![String::from("RunMat Accelerate Planner")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Each char row becomes one string; the resulting 3x1 column is joined along dim 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_char_array_input_produces_string_array() {
        let data: Vec<char> = "RunMatGPUDev".chars().collect();
        let char_array = CharArray::new(data, 3, 4).unwrap();
        let result = join_builtin(Value::CharArray(char_array), Vec::new()).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![1, 1]);
                assert_eq!(sa.data, vec![String::from("RunM atGP UDev")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A 3x1 cell of delimiters supplies one delimiter per row for a 3x2 join along dim 2.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_with_cell_delimiter_array() {
        let array = StringArray::new(
            vec![
                "g".into(),
                "c".into(),
                "w".into(),
                "gpu".into(),
                "cuda".into(),
                "wgpu".into(),
            ],
            vec![3, 2],
        )
        .unwrap();
        let delimiters = make_cell(
            vec![
                Value::String(String::from(" -> ")),
                Value::String(String::from(" => ")),
                Value::String(String::from(" :: ")),
            ],
            3,
            1,
        )
        .expect("cell");
        let result = join_builtin(
            Value::StringArray(array),
            vec![delimiters, Value::Int(IntValue::I32(2))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![3, 1]);
                assert_eq!(
                    sa.data,
                    vec![
                        String::from("g -> gpu"),
                        String::from("c => cuda"),
                        String::from("w :: wgpu")
                    ]
                );
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Joining a 2x2x2 array along dim 3 pairs page 0 with page 1 at every (row, col).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_3d_string_array_along_third_dimension() {
        let mut data = Vec::new();
        for page in 0..2 {
            for col in 0..2 {
                for row in 0..2 {
                    data.push(format!("r{row}c{col}p{page}"));
                }
            }
        }
        let array = StringArray::new(data, vec![2, 2, 2]).unwrap();
        let result = join_builtin(
            Value::StringArray(array),
            vec![Value::String(":".into()), Value::Int(IntValue::I32(3))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 2, 1]);
                let expected = vec![
                    String::from("r0c0p0:r0c0p1"),
                    String::from("r1c0p0:r1c0p1"),
                    String::from("r0c1p0:r0c1p1"),
                    String::from("r1c1p0:r1c1p1"),
                ];
                assert_eq!(sa.data, expected);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Dimension 0 is rejected with the dimension error message.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_errors_on_zero_dimension() {
        let array = StringArray::new(vec!["a".into()], vec![1, 1]).unwrap();
        let err = join_builtin(
            Value::StringArray(array),
            vec![Value::Int(IntValue::I32(0))],
        )
        .unwrap_err();
        let err_text = err.to_string();
        assert!(
            err_text.contains("dimension"),
            "expected dimension error, got {err_text}"
        );
    }

    // A 1x3 input joined along dim 2 has two gaps, so a 1x3 delimiter array is rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_errors_on_mismatched_delimiter_shape() {
        let array = StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![1, 3]).unwrap();
        let delims =
            StringArray::new(vec!["+".into(), "-".into(), "=".into()], vec![1, 3]).unwrap();
        let result = join_builtin(Value::StringArray(array), vec![Value::StringArray(delims)]);
        assert!(result.is_err());
    }

    // Any missing element makes the whole joined lane "<missing>".
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_propagates_missing_strings() {
        let array = StringArray::new(vec!["GPU".into(), "<missing>".into()], vec![1, 2]).unwrap();
        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.data, vec![String::from("<missing>")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A 1xN char row is accepted as a scalar delimiter.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_accepts_char_delimiter_scalar() {
        let array = StringArray::new(vec!["A".into(), "B".into()], vec![1, 2]).unwrap();
        let delimiter_chars = CharArray::new("++".chars().collect::<Vec<char>>(), 1, 2).unwrap();
        let result = join_builtin(
            Value::StringArray(array),
            vec![Value::CharArray(delimiter_chars)],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.data, vec![String::from("A++B")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // 2x0 input: the default dimension is 2 (size 0 != 1), so each row yields an empty string.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_handles_empty_axis() {
        let array = StringArray::new(Vec::new(), vec![2, 0]).unwrap();
        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(sa.data, vec![String::from(""), String::from("")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A 1x1 delimiter array is treated as a scalar and broadcast to every row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn join_missing_dimension_broadcast_delimiters() {
        let array = StringArray::new(
            vec!["aa".into(), "cc".into(), "bb".into(), "dd".into()],
            vec![2, 2],
        )
        .unwrap();
        let delims = StringArray::new(vec!["-".into()], vec![1, 1]).unwrap();
        let result = join_builtin(
            Value::StringArray(array),
            vec![Value::StringArray(delims), Value::Int(IntValue::I32(2))],
        )
        .expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.shape, vec![2, 1]);
                assert_eq!(sa.data, vec![String::from("aa-bb"), String::from("cc-dd")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // With a wgpu provider registered, join still runs (host-side) and returns the same text.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn join_executes_with_wgpu_provider_registered() {
        let _ = wgpu_backend::register_wgpu_provider(wgpu_backend::WgpuProviderOptions::default());
        let array = StringArray::new(vec!["GPU".into(), "Planner".into()], vec![2, 1]).unwrap();
        let result = join_builtin(Value::StringArray(array), Vec::new()).expect("join");
        match result {
            Value::StringArray(sa) => {
                assert_eq!(sa.data, vec![String::from("GPU Planner")]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // The type resolver maps a string argument to a string result.
    #[test]
    fn join_type_concatenates_text() {
        assert_eq!(
            text_concat_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::String
        );
    }
}