Skip to main content

runmat_runtime/builtins/strings/core/
compose.rs

1//! MATLAB-compatible `compose` builtin that formats data into string arrays.
2use runmat_builtins::{StringArray, Value};
3use runmat_macros::runtime_builtin;
4
5use crate::builtins::common::map_control_flow_with_builtin;
6use crate::builtins::common::spec::{
7    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8    ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::builtins::strings::core::string::{
11    extract_format_spec, format_from_spec, FormatSpecData,
12};
13use crate::builtins::strings::type_resolvers::string_array_type;
14use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
15
16#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::compose")]
17pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
18    name: "compose",
19    op_kind: GpuOpKind::Custom("format"),
20    supported_precisions: &[],
21    broadcast: BroadcastSemantics::None,
22    provider_hooks: &[],
23    constant_strategy: ConstantStrategy::InlineLiteral,
24    residency: ResidencyPolicy::GatherImmediately,
25    nan_mode: ReductionNaN::Include,
26    two_pass_threshold: None,
27    workgroup_size: None,
28    accepts_nan_mode: false,
29    notes: "Formatting always executes on the CPU; GPU tensors are gathered before substitution.",
30};
31
32#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::compose")]
33pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
34    name: "compose",
35    shape: ShapeRequirements::Any,
36    constant_strategy: ConstantStrategy::InlineLiteral,
37    elementwise: None,
38    reduction: None,
39    emits_nan: false,
40    notes: "Formatting builtin; not eligible for fusion and materialises host string arrays.",
41};
42
43fn compose_flow(message: impl Into<String>) -> RuntimeError {
44    build_runtime_error(message).with_builtin("compose").build()
45}
46
47fn remap_compose_flow(mut err: RuntimeError) -> RuntimeError {
48    err = map_control_flow_with_builtin(err, "compose");
49    if let Some(message) = err.message.strip_prefix("string: ") {
50        err.message = format!("compose: {message}");
51        return err;
52    }
53    if !err.message.starts_with("compose: ") {
54        err.message = format!("compose: {}", err.message);
55    }
56    err
57}
58
59#[runtime_builtin(
60    name = "compose",
61    category = "strings/core",
62    summary = "Format values into MATLAB string arrays using printf-style placeholders.",
63    keywords = "compose,format,string array,gpu",
64    accel = "sink",
65    type_resolver(string_array_type),
66    builtin_path = "crate::builtins::strings::core::compose"
67)]
68async fn compose_builtin(format_spec: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
69    let format_value = gather_if_needed_async(&format_spec)
70        .await
71        .map_err(remap_compose_flow)?;
72    let mut gathered_args = Vec::with_capacity(rest.len());
73    for arg in rest {
74        let gathered = gather_if_needed_async(&arg)
75            .await
76            .map_err(remap_compose_flow)?;
77        gathered_args.push(gathered);
78    }
79
80    if gathered_args.is_empty() {
81        let spec = extract_format_spec(format_value)
82            .await
83            .map_err(remap_compose_flow)?;
84        let array = format_spec_data_to_string_array(spec)?;
85        return Ok(Value::StringArray(array));
86    }
87
88    let formatted = format_from_spec(format_value, gathered_args)
89        .await
90        .map_err(remap_compose_flow)?;
91    Ok(Value::StringArray(formatted))
92}
93
94fn format_spec_data_to_string_array(spec: FormatSpecData) -> BuiltinResult<StringArray> {
95    let shape = if spec.shape.is_empty() {
96        match spec.specs.len() {
97            0 => vec![0, 0],
98            1 => vec![1, 1],
99            len => vec![len, 1],
100        }
101    } else {
102        spec.shape
103    };
104    StringArray::new(spec.specs, shape).map_err(|e| compose_flow(format!("compose: {e}")))
105}
106
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{IntValue, ResolveContext, Tensor, Type};

    /// Drives the async builtin to completion so tests can call it synchronously.
    fn compose_builtin(format_spec: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(super::compose_builtin(format_spec, rest))
    }

    /// Returns the message text carried by a runtime error.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Unwraps a `Value::StringArray`, panicking with the actual value otherwise.
    fn expect_string_array(value: Value) -> StringArray {
        match value {
            Value::StringArray(sa) => sa,
            other => panic!("expected string array, got {other:?}"),
        }
    }

    /// Builds an owned `Vec<String>` from literals for data comparisons.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_scalar_numeric() {
        let args = vec![Value::Int(IntValue::I32(7))];
        let result = compose_builtin(Value::from("Count %d"), args).expect("compose");
        let sa = expect_string_array(result);
        assert_eq!(sa.shape, vec![1, 1]);
        assert_eq!(sa.data, strings(&["Count 7"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_broadcasts_scalar_spec() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let result = compose_builtin(Value::from("Item %0.0f"), vec![Value::Tensor(tensor)])
            .expect("compose");
        let sa = expect_string_array(result);
        assert_eq!(sa.shape, vec![1, 2]);
        assert_eq!(sa.data, strings(&["Item 1", "Item 2"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_zero_arguments_returns_spec() {
        let names = StringArray::new(vec!["alpha".into(), "beta".into()], vec![1, 2]).unwrap();
        let result = compose_builtin(Value::StringArray(names), Vec::new()).expect("compose");
        let sa = expect_string_array(result);
        assert_eq!(sa.shape, vec![1, 2]);
        assert_eq!(sa.data, strings(&["alpha", "beta"]));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_mismatched_lengths_errors() {
        let specs = StringArray::new(vec!["%d".into(), "%d".into()], vec![1, 2]).unwrap();
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
        let failure = compose_builtin(Value::StringArray(specs), vec![Value::Tensor(tensor)])
            .unwrap_err();
        let err = error_message(failure);
        assert!(
            err.starts_with("compose: "),
            "expected compose prefix, got {err}"
        );
        assert!(
            err.contains("format data arguments must be scalars or match formatSpec size"),
            "unexpected error text: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_gpu_argument() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                compose_builtin(Value::from("Value %0.0f"), vec![Value::GpuTensor(handle)])
                    .expect("compose");
            let sa = expect_string_array(result);
            assert_eq!(sa.shape, vec![1, 3]);
            assert_eq!(sa.data, strings(&["Value 1", "Value 2", "Value 3"]));
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn compose_wgpu_numeric_tensor_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![1.25, 2.5, 3.75], vec![1, 3]).unwrap();
        let cpu = compose_builtin(
            Value::from("Value %0.2f"),
            vec![Value::Tensor(tensor.clone())],
        )
        .expect("cpu compose");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&view).expect("gpu upload");
        let gpu = compose_builtin(Value::from("Value %0.2f"), vec![Value::GpuTensor(handle)])
            .expect("gpu compose");
        match (cpu, gpu) {
            (Value::StringArray(expect), Value::StringArray(actual)) => {
                assert_eq!(actual.shape, expect.shape);
                assert_eq!(actual.data, expect.data);
            }
            other => panic!("unexpected results {other:?}"),
        }
    }

    #[test]
    fn compose_type_is_string_array() {
        let ctx = ResolveContext::new(Vec::new());
        assert_eq!(
            string_array_type(&[Type::String], &ctx),
            Type::cell_of(Type::String)
        );
    }
}