1use runmat_builtins::{LogicalArray, StringArray, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::map_control_flow_with_builtin;
7use crate::builtins::common::random_args::{keyword_of, shape_from_value};
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::builtins::strings::type_resolvers::string_array_type;
13use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
14
/// Builtin name; prefixes every error message and names the GPU/fusion specs.
const FN_NAME: &str = "strings";
// Error detail fragments for size validation. The unit tests below match on these
// substrings, so changing the text requires updating the tests too.
/// Size value had a fractional part.
const SIZE_INTEGER_ERR: &str = "size inputs must be integers";
/// Size value was negative.
const SIZE_NONNEGATIVE_ERR: &str = "size inputs must be nonnegative integers";
/// Size value was NaN or infinite.
const SIZE_FINITE_ERR: &str = "size inputs must be finite";
/// Size argument had an unsupported value kind (e.g. a non-keyword string).
const SIZE_NUMERIC_ERR: &str = "size arguments must be numeric scalars or vectors";
/// A per-dimension argument (multi-argument form) held more than one element.
const SIZE_SCALAR_ERR: &str = "size inputs must be scalar";
21
22fn strings_flow(message: impl Into<String>) -> RuntimeError {
23 build_runtime_error(message).with_builtin(FN_NAME).build()
24}
25
26fn remap_strings_flow(err: RuntimeError) -> RuntimeError {
27 map_control_flow_with_builtin(err, FN_NAME)
28}
29
// Acceleration metadata for `strings`. No precisions or provider hooks are declared and
// residency is `GatherImmediately`: per the `notes` text, the builtin runs on the host and any
// GPU-resident size arguments are gathered before allocation.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strings")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: FN_NAME,
    op_kind: GpuOpKind::Custom("array_creation"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Runs entirely on the host; size arguments pulled from the GPU are gathered before allocation.",
};
45
// Fusion metadata for `strings`. Neither an elementwise nor a reduction kernel is declared
// (both `None`), and the builtin emits no NaNs; it only preallocates host string arrays.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strings")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: FN_NAME,
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Preallocates host string arrays; no fusion-supported kernels are generated.",
};
56
/// Result of argument parsing for `strings`.
struct ParsedStrings {
    // Output dimensions; never empty (parse_arguments substitutes [0, 0] for an empty shape).
    shape: Vec<usize>,
    // Which value every element of the output is initialised with.
    fill: FillKind,
}
61
/// Fill value selected by the `'empty'` (default) / `'missing'` keyword options.
#[derive(Clone, Copy, PartialEq, Eq)]
enum FillKind {
    // Each element is the empty string "".
    Empty,
    // Each element is the text "<missing>".
    Missing,
}
67
68#[runtime_builtin(
69 name = "strings",
70 category = "strings/core",
71 summary = "Preallocate string arrays filled with empty string scalars.",
72 keywords = "strings,string array,empty,preallocate",
73 accel = "array_construct",
74 type_resolver(string_array_type),
75 builtin_path = "crate::builtins::strings::core::strings"
76)]
77async fn strings_builtin(rest: Vec<Value>) -> crate::BuiltinResult<Value> {
78 let ParsedStrings { shape, fill } = parse_arguments(rest).await?;
79 let total = shape.iter().try_fold(1usize, |acc, &dim| {
80 acc.checked_mul(dim).ok_or_else(|| {
81 strings_flow(format!("{FN_NAME}: requested size exceeds platform limits"))
82 })
83 })?;
84
85 let fill_text = match fill {
86 FillKind::Empty => String::new(),
87 FillKind::Missing => "<missing>".to_string(),
88 };
89
90 let mut data = Vec::with_capacity(total);
91 for _ in 0..total {
92 data.push(fill_text.clone());
93 }
94
95 let array =
96 StringArray::new(data, shape).map_err(|e| strings_flow(format!("{FN_NAME}: {e}")))?;
97 Ok(Value::StringArray(array))
98}
99
100async fn parse_arguments(args: Vec<Value>) -> BuiltinResult<ParsedStrings> {
101 let mut size_values: Vec<Value> = Vec::new();
102 let mut like_proto: Option<Value> = None;
103 let mut fill = FillKind::Empty;
104
105 let mut idx = 0;
106 while idx < args.len() {
107 let host = gather_if_needed_async(&args[idx])
108 .await
109 .map_err(remap_strings_flow)?;
110 if let Some(keyword) = keyword_of(&host) {
111 match keyword.as_str() {
112 "like" => {
113 if like_proto.is_some() {
114 return Err(strings_flow(format!(
115 "{FN_NAME}: multiple 'like' specifications are not supported"
116 )));
117 }
118 let Some(proto_raw) = args.get(idx + 1) else {
119 return Err(strings_flow(format!(
120 "{FN_NAME}: expected prototype after 'like'"
121 )));
122 };
123 let proto = gather_if_needed_async(proto_raw)
124 .await
125 .map_err(remap_strings_flow)?;
126 like_proto = Some(proto);
127 idx += 2;
128 continue;
129 }
130 "missing" => {
131 fill = FillKind::Missing;
132 idx += 1;
133 continue;
134 }
135 "empty" => {
136 fill = FillKind::Empty;
137 idx += 1;
138 continue;
139 }
140 _ => {}
141 }
142 }
143 size_values.push(host);
144 idx += 1;
145 }
146
147 let dims = parse_size_values(size_values)?;
148 let mut shape = if let Some(dims) = dims {
149 normalize_dims(dims)
150 } else if let Some(proto) = like_proto.as_ref() {
151 prototype_shape(proto)?
152 } else {
153 vec![1, 1]
154 };
155
156 if shape.is_empty() {
157 shape = vec![0, 0];
158 }
159
160 Ok(ParsedStrings { shape, fill })
161}
162
163fn prototype_shape(value: &Value) -> BuiltinResult<Vec<usize>> {
164 match value {
165 Value::StringArray(sa) => Ok(sa.shape.clone()),
166 _ => shape_from_value(value, FN_NAME).map_err(strings_flow),
167 }
168}
169
170fn err_integer() -> RuntimeError {
171 strings_flow(format!("{FN_NAME}: {SIZE_INTEGER_ERR}"))
172}
173
174fn err_nonnegative() -> RuntimeError {
175 strings_flow(format!("{FN_NAME}: {SIZE_NONNEGATIVE_ERR}"))
176}
177
178fn err_finite() -> RuntimeError {
179 strings_flow(format!("{FN_NAME}: {SIZE_FINITE_ERR}"))
180}
181
182fn parse_size_values(values: Vec<Value>) -> BuiltinResult<Option<Vec<usize>>> {
183 match values.len() {
184 0 => Ok(None),
185 1 => parse_single_argument(values.into_iter().next().unwrap()).map(Some),
186 _ => {
187 let mut dims = Vec::with_capacity(values.len());
188 for value in &values {
189 dims.push(parse_size_scalar(value)?);
190 }
191 Ok(Some(dims))
192 }
193 }
194}
195
196fn parse_single_argument(value: Value) -> BuiltinResult<Vec<usize>> {
197 match value {
198 Value::Int(iv) => Ok(vec![validate_i64_dimension(iv.to_i64())?]),
199 Value::Num(n) => Ok(vec![parse_numeric_dimension(n)?]),
200 Value::Bool(b) => Ok(vec![if b { 1 } else { 0 }]),
201 Value::Tensor(t) => parse_size_tensor(&t),
202 Value::LogicalArray(arr) => parse_size_logical_array(&arr),
203 other => Err(strings_flow(format!(
204 "{FN_NAME}: {SIZE_NUMERIC_ERR}, got {other:?}"
205 ))),
206 }
207}
208
209fn parse_size_scalar(value: &Value) -> BuiltinResult<usize> {
210 match value {
211 Value::Int(iv) => {
212 let raw = iv.to_i64();
213 validate_i64_dimension(raw)
214 }
215 Value::Num(n) => parse_numeric_dimension(*n),
216 Value::Bool(b) => Ok(if *b { 1 } else { 0 }),
217 Value::Tensor(t) => {
218 if t.data.len() != 1 {
219 return Err(strings_flow(format!("{FN_NAME}: {SIZE_SCALAR_ERR}")));
220 }
221 parse_numeric_dimension(t.data[0])
222 }
223 Value::LogicalArray(arr) => {
224 if arr.data.len() != 1 {
225 return Err(strings_flow(format!("{FN_NAME}: {SIZE_SCALAR_ERR}")));
226 }
227 Ok(if arr.data[0] != 0 { 1 } else { 0 })
228 }
229 other => Err(strings_flow(format!(
230 "{FN_NAME}: {SIZE_NUMERIC_ERR}, got {other:?}"
231 ))),
232 }
233}
234
235fn parse_size_tensor(tensor: &Tensor) -> BuiltinResult<Vec<usize>> {
236 if tensor.data.is_empty() {
237 return Ok(vec![0, 0]);
238 }
239 if !is_vector_shape(&tensor.shape) {
240 return Err(strings_flow(format!(
241 "{FN_NAME}: size vector must be a row or column vector"
242 )));
243 }
244 tensor
245 .data
246 .iter()
247 .map(|&value| parse_numeric_dimension(value))
248 .collect()
249}
250
251fn parse_size_logical_array(array: &LogicalArray) -> BuiltinResult<Vec<usize>> {
252 if array.data.is_empty() {
253 return Ok(vec![0, 0]);
254 }
255 if !is_vector_shape(&array.shape) {
256 return Err(strings_flow(format!(
257 "{FN_NAME}: size vector must be a row or column vector"
258 )));
259 }
260 array
261 .data
262 .iter()
263 .map(|&value| Ok(if value != 0 { 1 } else { 0 }))
264 .collect()
265}
266
267fn parse_numeric_dimension(value: f64) -> BuiltinResult<usize> {
268 if !value.is_finite() {
269 return Err(err_finite());
270 }
271 let rounded = value.round();
272 if (rounded - value).abs() > f64::EPSILON {
273 return Err(err_integer());
274 }
275 if rounded < 0.0 {
276 return Err(err_nonnegative());
277 }
278 if rounded > usize::MAX as f64 {
279 return Err(strings_flow(format!(
280 "{FN_NAME}: requested dimension exceeds platform limits"
281 )));
282 }
283 Ok(rounded as usize)
284}
285
/// Expands parsed dimensions into an output shape.
///
/// No dimensions give `[0, 0]`, and a single dimension `n` gives the square `[n, n]`. Two or
/// more dimensions are returned unchanged, so trailing singletons are preserved.
fn normalize_dims(dims: Vec<usize>) -> Vec<usize> {
    if dims.is_empty() {
        return vec![0, 0];
    }
    if let [side] = dims[..] {
        return vec![side; 2];
    }
    dims
}
296
/// Reports whether `shape` describes a vector.
///
/// - 0-D and 1-D shapes always qualify.
/// - A 2-D shape qualifies when either extent is 1.
/// - Higher-rank shapes qualify when at most one extent exceeds 1.
fn is_vector_shape(shape: &[usize]) -> bool {
    match shape {
        [] | [_] => true,
        [rows, cols] => *rows == 1 || *cols == 1,
        dims => dims.iter().filter(|&&extent| extent > 1).count() <= 1,
    }
}
304
305fn validate_i64_dimension(raw: i64) -> BuiltinResult<usize> {
306 if raw < 0 {
307 return Err(err_nonnegative());
308 }
309 if (raw as u128) > (usize::MAX as u128) {
310 return Err(strings_flow(format!(
311 "{FN_NAME}: requested dimension exceeds platform limits"
312 )));
313 }
314 Ok(raw as usize)
315}
316
// Unit tests for the `strings` builtin. Each test calls the async entry point through the
// synchronous `strings_builtin` wrapper below and inspects the resulting `StringArray`.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{ResolveContext, Type};

    // Synchronous wrapper: blocks on the async builtin so tests can call it directly.
    fn strings_builtin(rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(super::strings_builtin(rest))
    }

    // Extracts the error message text for substring assertions.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    // No arguments: a 1x1 array holding one empty string.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_default_scalar() {
        let result = strings_builtin(Vec::new()).expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![1, 1]);
                assert_eq!(array.data, vec![String::new()]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A single scalar dimension n produces an n-by-n array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_square_from_single_dimension() {
        let args = vec![Value::Num(4.0)];
        let result = strings_builtin(args).expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![4, 4]);
                assert!(array.data.iter().all(|s| s.is_empty()));
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Mixed integer/double scalar arguments each contribute one dimension.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_rectangular_multiple_args() {
        let args = vec![
            Value::Int(runmat_builtins::IntValue::I32(2)),
            Value::Num(3.0),
        ];
        let result = strings_builtin(args).expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![2, 3]);
                assert_eq!(array.data.len(), 6);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A numeric size vector is used as the shape verbatim (including a trailing 1).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_from_size_vector_tensor() {
        let dims = Tensor::new(vec![2.0, 3.0, 1.0], vec![1, 3]).unwrap();
        let result = strings_builtin(vec![Value::Tensor(dims)]).expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![2, 3, 1]);
                assert_eq!(array.data.len(), 6);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Trailing singleton dimensions from separate arguments are kept in the shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_preserves_trailing_singletons() {
        let args = vec![
            Value::Num(3.0),
            Value::Int(runmat_builtins::IntValue::I32(1)),
            Value::Num(1.0),
            Value::Bool(true),
        ];
        let result = strings_builtin(args).expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![3, 1, 1, 1]);
                assert_eq!(array.data.len(), 3);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Boolean scalars act as dimensions 1 (true) and 0 (false).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_bool_dimensions() {
        let result = strings_builtin(vec![Value::Bool(true), Value::Bool(false)]).expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![1, 0]);
                assert!(array.data.is_empty());
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A logical size vector maps nonzero entries to 1 and zero entries to 0.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_logical_vector_argument() {
        let logical =
            LogicalArray::new(vec![1u8, 0, 1], vec![1, 3]).expect("logical size construction");
        let result = strings_builtin(vec![Value::LogicalArray(logical)]).expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![1, 0, 1]);
                assert!(array.data.is_empty());
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Negative sizes are rejected rather than clamped to zero.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_negative_dimension_errors() {
        let err =
            error_message(strings_builtin(vec![Value::Num(-5.0)]).expect_err("expected error"));
        assert!(err.contains(super::SIZE_NONNEGATIVE_ERR));
    }

    // Fractional sizes are rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_rejects_non_integer_dimension() {
        let err =
            error_message(strings_builtin(vec![Value::Num(2.5)]).expect_err("expected error"));
        assert!(err.contains(super::SIZE_INTEGER_ERR));
    }

    // A string that is not a recognised keyword falls through to size parsing and fails there.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_rejects_non_numeric_dimension() {
        let err = error_message(
            strings_builtin(vec![Value::String("size".into())]).expect_err("expected error"),
        );
        assert!(err.contains("size arguments must be numeric"));
    }

    // An empty size vector yields a 0x0 result.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_empty_vector_returns_empty_array() {
        let dims = Tensor::new(Vec::<f64>::new(), vec![0, 0]).unwrap();
        let result = strings_builtin(vec![Value::Tensor(dims)]).expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![0, 0]);
                assert!(array.data.is_empty());
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // The 'missing' option fills every element with "<missing>".
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_missing_option_fills_with_missing() {
        let result = strings_builtin(vec![
            Value::Num(2.0),
            Value::Num(3.0),
            Value::String("missing".into()),
        ])
        .expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![2, 3]);
                assert_eq!(array.data.len(), 6);
                assert!(array.data.iter().all(|s| s == "<missing>"));
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // 'missing' alone still defaults to a 1x1 shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_missing_without_dims_defaults_to_scalar() {
        let result = strings_builtin(vec![Value::String("missing".into())]).expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![1, 1]);
                assert_eq!(array.data, vec!["<missing>".to_string()]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // 'like' with a string-array prototype copies its shape but not its contents.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_like_prototype_shape() {
        let proto = StringArray::new(
            vec!["alpha".into(), "beta".into(), "gamma".into()],
            vec![3, 1],
        )
        .unwrap();
        let result = strings_builtin(vec![
            Value::String("like".into()),
            Value::StringArray(proto.clone()),
        ])
        .expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, proto.shape);
                assert!(array.data.iter().all(|s| s.is_empty()));
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // 'like' also accepts a numeric prototype and mirrors its shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_like_numeric_prototype() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let result = strings_builtin(vec![
            Value::String("like".into()),
            Value::Tensor(tensor.clone()),
        ])
        .expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, tensor.shape);
                assert_eq!(array.data.len(), tensor.data.len());
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // Explicit size arguments take precedence over the prototype's shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_like_overrides_shape_when_dims_provided() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let result = strings_builtin(vec![
            Value::String("like".into()),
            Value::Tensor(tensor),
            Value::Int(runmat_builtins::IntValue::I32(3)),
        ])
        .expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![3, 3]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }

    // A trailing 'like' with no operand is an error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_like_requires_prototype() {
        let err = error_message(
            strings_builtin(vec![Value::String("like".into())]).expect_err("expected error"),
        );
        assert!(err.contains("expected prototype"));
    }

    // A second 'like' specification is rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_like_rejects_multiple_specs() {
        let err = error_message(
            strings_builtin(vec![
                Value::String("like".into()),
                Value::Num(1.0),
                Value::String("like".into()),
                Value::Num(2.0),
            ])
            .expect_err("expected error"),
        );
        assert!(err.contains("multiple 'like'"));
    }

    // A size vector uploaded to the test GPU provider is gathered before parsing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_gpu_size_vector_argument() {
        test_support::with_test_provider(|provider| {
            let dims = Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap();
            let view = HostTensorView {
                data: &dims.data,
                shape: &dims.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = strings_builtin(vec![Value::GpuTensor(handle)]).expect("strings");
            match result {
                Value::StringArray(array) => {
                    assert_eq!(array.shape, vec![2, 3]);
                    assert_eq!(array.data.len(), 6);
                }
                other => panic!("expected string array, got {other:?}"),
            }
        });
    }

    // A GPU-resident 'like' prototype is gathered and its shape mirrored.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_like_accepts_gpu_prototype() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                strings_builtin(vec![Value::String("like".into()), Value::GpuTensor(handle)])
                    .expect("strings");
            match result {
                Value::StringArray(array) => {
                    assert_eq!(array.shape, vec![2, 2]);
                }
                other => panic!("expected string array, got {other:?}"),
            }
        });
    }

    // Checks the static type resolver's output for a numeric argument.
    // NOTE(review): the test name says "string array" but the expected type is
    // `Type::cell_of(Type::String)` — confirm this is the intended resolver result.
    #[test]
    fn strings_type_is_string_array() {
        assert_eq!(
            string_array_type(&[Type::Num], &ResolveContext::new(Vec::new())),
            Type::cell_of(Type::String)
        );
    }

    // Only built with the `wgpu` feature: registers the wgpu provider (ignoring the result, so
    // re-registration is tolerated) and checks that a wgpu-resident size vector is gathered.
    #[cfg(feature = "wgpu")]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strings_handles_wgpu_size_vectors() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let dims = Tensor::new(vec![1.0, 4.0], vec![1, 2]).unwrap();
        let view = HostTensorView {
            data: &dims.data,
            shape: &dims.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&view).expect("upload");
        let result = strings_builtin(vec![Value::GpuTensor(handle)]).expect("strings");
        match result {
            Value::StringArray(array) => {
                assert_eq!(array.shape, vec![1, 4]);
            }
            other => panic!("expected string array, got {other:?}"),
        }
    }
}