runmat_runtime/builtins/strings/core/
compose.rs1use runmat_builtins::{StringArray, Value};
3use runmat_macros::runtime_builtin;
4
5use crate::builtins::common::map_control_flow_with_builtin;
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::strings::core::string::{
11 extract_format_spec, format_from_spec, FormatSpecData,
12};
13use crate::builtins::strings::type_resolvers::string_array_type;
14use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
15
16#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::compose")]
17pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
18 name: "compose",
19 op_kind: GpuOpKind::Custom("format"),
20 supported_precisions: &[],
21 broadcast: BroadcastSemantics::None,
22 provider_hooks: &[],
23 constant_strategy: ConstantStrategy::InlineLiteral,
24 residency: ResidencyPolicy::GatherImmediately,
25 nan_mode: ReductionNaN::Include,
26 two_pass_threshold: None,
27 workgroup_size: None,
28 accepts_nan_mode: false,
29 notes: "Formatting always executes on the CPU; GPU tensors are gathered before substitution.",
30};
31
32#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::compose")]
33pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
34 name: "compose",
35 shape: ShapeRequirements::Any,
36 constant_strategy: ConstantStrategy::InlineLiteral,
37 elementwise: None,
38 reduction: None,
39 emits_nan: false,
40 notes: "Formatting builtin; not eligible for fusion and materialises host string arrays.",
41};
42
43fn compose_flow(message: impl Into<String>) -> RuntimeError {
44 build_runtime_error(message).with_builtin("compose").build()
45}
46
47fn remap_compose_flow(mut err: RuntimeError) -> RuntimeError {
48 err = map_control_flow_with_builtin(err, "compose");
49 if let Some(message) = err.message.strip_prefix("string: ") {
50 err.message = format!("compose: {message}");
51 return err;
52 }
53 if !err.message.starts_with("compose: ") {
54 err.message = format!("compose: {}", err.message);
55 }
56 err
57}
58
59#[runtime_builtin(
60 name = "compose",
61 category = "strings/core",
62 summary = "Format values into MATLAB string arrays using printf-style placeholders.",
63 keywords = "compose,format,string array,gpu",
64 accel = "sink",
65 type_resolver(string_array_type),
66 builtin_path = "crate::builtins::strings::core::compose"
67)]
68async fn compose_builtin(format_spec: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
69 let format_value = gather_if_needed_async(&format_spec)
70 .await
71 .map_err(remap_compose_flow)?;
72 let mut gathered_args = Vec::with_capacity(rest.len());
73 for arg in rest {
74 let gathered = gather_if_needed_async(&arg)
75 .await
76 .map_err(remap_compose_flow)?;
77 gathered_args.push(gathered);
78 }
79
80 if gathered_args.is_empty() {
81 let spec = extract_format_spec(format_value)
82 .await
83 .map_err(remap_compose_flow)?;
84 let array = format_spec_data_to_string_array(spec)?;
85 return Ok(Value::StringArray(array));
86 }
87
88 let formatted = format_from_spec(format_value, gathered_args)
89 .await
90 .map_err(remap_compose_flow)?;
91 Ok(Value::StringArray(formatted))
92}
93
94fn format_spec_data_to_string_array(spec: FormatSpecData) -> BuiltinResult<StringArray> {
95 let shape = if spec.shape.is_empty() {
96 match spec.specs.len() {
97 0 => vec![0, 0],
98 1 => vec![1, 1],
99 len => vec![len, 1],
100 }
101 } else {
102 spec.shape
103 };
104 StringArray::new(spec.specs, shape).map_err(|e| compose_flow(format!("compose: {e}")))
105}
106
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{IntValue, ResolveContext, Tensor, Type};

    /// Blocking wrapper around the async builtin so tests can call it directly.
    fn compose_builtin(format_spec: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        let pending = super::compose_builtin(format_spec, rest);
        futures::executor::block_on(pending)
    }

    /// Extracts the message text from a runtime error.
    fn error_message(err: crate::RuntimeError) -> String {
        let text = err.message();
        text.to_string()
    }

    /// Unwraps a string-array result, panicking on any other value kind.
    fn into_string_array(value: Value) -> StringArray {
        match value {
            Value::StringArray(sa) => sa,
            other => panic!("expected string array, got {other:?}"),
        }
    }

    /// Builds the owned `Vec<String>` that string-array data is compared against.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_scalar_numeric() {
        let args = vec![Value::Int(IntValue::I32(7))];
        let value = compose_builtin(Value::from("Count %d"), args).expect("compose");
        let sa = into_string_array(value);
        assert_eq!(sa.shape, vec![1, 1]);
        assert_eq!(sa.data, strings(&["Count 7"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_broadcasts_scalar_spec() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let args = vec![Value::Tensor(tensor)];
        let value = compose_builtin(Value::from("Item %0.0f"), args).expect("compose");
        let sa = into_string_array(value);
        assert_eq!(sa.shape, vec![1, 2]);
        assert_eq!(sa.data, strings(&["Item 1", "Item 2"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_zero_arguments_returns_spec() {
        let spec_array = StringArray::new(vec!["alpha".into(), "beta".into()], vec![1, 2]).unwrap();
        let value = compose_builtin(Value::StringArray(spec_array), Vec::new()).expect("compose");
        let sa = into_string_array(value);
        assert_eq!(sa.shape, vec![1, 2]);
        assert_eq!(sa.data, strings(&["alpha", "beta"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_mismatched_lengths_errors() {
        // Two format specs against three data values cannot be matched up.
        let spec_array = StringArray::new(vec!["%d".into(), "%d".into()], vec![1, 2]).unwrap();
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
        let failure = compose_builtin(Value::StringArray(spec_array), vec![Value::Tensor(tensor)])
            .unwrap_err();
        let err = error_message(failure);
        assert!(
            err.starts_with("compose: "),
            "expected compose prefix, got {err}"
        );
        assert!(
            err.contains("format data arguments must be scalars or match formatSpec size"),
            "unexpected error text: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_gpu_argument() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let args = vec![Value::GpuTensor(handle)];
            let value = compose_builtin(Value::from("Value %0.0f"), args).expect("compose");
            let sa = into_string_array(value);
            assert_eq!(sa.shape, vec![1, 3]);
            assert_eq!(sa.data, strings(&["Value 1", "Value 2", "Value 3"]));
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn compose_wgpu_numeric_tensor_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![1.25, 2.5, 3.75], vec![1, 3]).unwrap();

        // Reference result computed from the host tensor.
        let host_args = vec![Value::Tensor(tensor.clone())];
        let cpu = compose_builtin(Value::from("Value %0.2f"), host_args).expect("cpu compose");

        // Same data, uploaded to the registered provider first.
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&view).expect("gpu upload");
        let device_args = vec![Value::GpuTensor(handle)];
        let gpu = compose_builtin(Value::from("Value %0.2f"), device_args).expect("gpu compose");

        match (cpu, gpu) {
            (Value::StringArray(expect), Value::StringArray(actual)) => {
                assert_eq!(actual.shape, expect.shape);
                assert_eq!(actual.data, expect.data);
            }
            other => panic!("unexpected results {other:?}"),
        }
    }

    #[test]
    fn compose_type_is_string_array() {
        let context = ResolveContext::new(Vec::new());
        let resolved = string_array_type(&[Type::String], &context);
        assert_eq!(resolved, Type::cell_of(Type::String));
    }
}