Skip to main content

runmat_runtime/builtins/strings/core/
compose.rs

1//! MATLAB-compatible `compose` builtin that formats data into string arrays.
2use runmat_builtins::{
3    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
4    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
5    StringArray, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::map_control_flow_with_builtin;
10use crate::builtins::common::spec::{
11    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
12    ReductionNaN, ResidencyPolicy, ShapeRequirements,
13};
14use crate::builtins::strings::core::string::{
15    extract_format_spec, format_from_spec, FormatSpecData,
16};
17use crate::builtins::strings::type_resolvers::string_array_type;
18use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
19
20#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::compose")]
21pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
22    name: "compose",
23    op_kind: GpuOpKind::Custom("format"),
24    supported_precisions: &[],
25    broadcast: BroadcastSemantics::None,
26    provider_hooks: &[],
27    constant_strategy: ConstantStrategy::InlineLiteral,
28    residency: ResidencyPolicy::GatherImmediately,
29    nan_mode: ReductionNaN::Include,
30    two_pass_threshold: None,
31    workgroup_size: None,
32    accepts_nan_mode: false,
33    notes: "Formatting always executes on the CPU; GPU tensors are gathered before substitution.",
34};
35
36#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::compose")]
37pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
38    name: "compose",
39    shape: ShapeRequirements::Any,
40    constant_strategy: ConstantStrategy::InlineLiteral,
41    elementwise: None,
42    reduction: None,
43    emits_nan: false,
44    notes: "Formatting builtin; not eligible for fusion and materialises host string arrays.",
45};
46
47const BUILTIN_NAME: &str = "compose";
48
49const COMPOSE_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
50    name: "S",
51    ty: BuiltinParamType::Any,
52    arity: BuiltinParamArity::Required,
53    default: None,
54    description: "Formatted string array output.",
55}];
56
57const COMPOSE_INPUT_NO_ARGS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
58    name: "formatSpec",
59    ty: BuiltinParamType::Any,
60    arity: BuiltinParamArity::Required,
61    default: None,
62    description: "Format text or array returned as strings when no data args are supplied.",
63}];
64
65const COMPOSE_INPUT_WITH_ARGS: [BuiltinParamDescriptor; 2] = [
66    BuiltinParamDescriptor {
67        name: "formatSpec",
68        ty: BuiltinParamType::Any,
69        arity: BuiltinParamArity::Required,
70        default: None,
71        description: "Format template text.",
72    },
73    BuiltinParamDescriptor {
74        name: "A...",
75        ty: BuiltinParamType::Any,
76        arity: BuiltinParamArity::Variadic,
77        default: None,
78        description: "Values substituted into formatSpec placeholders.",
79    },
80];
81
82const COMPOSE_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
83    BuiltinSignatureDescriptor {
84        label: "S = compose(formatSpec)",
85        inputs: &COMPOSE_INPUT_NO_ARGS,
86        outputs: &COMPOSE_OUTPUT,
87    },
88    BuiltinSignatureDescriptor {
89        label: "S = compose(formatSpec, A...)",
90        inputs: &COMPOSE_INPUT_WITH_ARGS,
91        outputs: &COMPOSE_OUTPUT,
92    },
93];
94
95const COMPOSE_ERROR_INVALID_FORMAT_SPEC: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
96    code: "RM.COMPOSE.INVALID_FORMAT_SPEC",
97    identifier: Some("RunMat:compose:InvalidFormatSpec"),
98    when: "formatSpec is not valid text input for compose formatting.",
99    message: "compose: invalid formatSpec",
100};
101
102const COMPOSE_ERROR_ARGUMENT_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
103    code: "RM.COMPOSE.ARGUMENT_MISMATCH",
104    identifier: Some("RunMat:compose:ArgumentMismatch"),
105    when: "Data arguments are not scalar or broadcast-compatible with formatSpec.",
106    message: "compose: format data arguments must be scalars or match formatSpec size",
107};
108
109const COMPOSE_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
110    code: "RM.COMPOSE.INTERNAL",
111    identifier: Some("RunMat:compose:InternalError"),
112    when: "Internal string-array construction failed.",
113    message: "compose: internal error",
114};
115
116const COMPOSE_ERRORS: [BuiltinErrorDescriptor; 3] = [
117    COMPOSE_ERROR_INVALID_FORMAT_SPEC,
118    COMPOSE_ERROR_ARGUMENT_MISMATCH,
119    COMPOSE_ERROR_INTERNAL,
120];
121
122pub const COMPOSE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
123    signatures: &COMPOSE_SIGNATURES,
124    output_mode: BuiltinOutputMode::Fixed,
125    completion_policy: BuiltinCompletionPolicy::Public,
126    errors: &COMPOSE_ERRORS,
127};
128
129fn compose_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
130    compose_error_with_message(error.message, error)
131}
132
133fn compose_error_with_message(
134    message: impl Into<String>,
135    error: &'static BuiltinErrorDescriptor,
136) -> RuntimeError {
137    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
138    if let Some(identifier) = error.identifier {
139        builder = builder.with_identifier(identifier);
140    }
141    builder.build()
142}
143
144fn remap_compose_flow(mut err: RuntimeError) -> RuntimeError {
145    err = map_control_flow_with_builtin(err, BUILTIN_NAME);
146    if let Some(message) = err.message.strip_prefix("string: ") {
147        err.message = format!("compose: {message}");
148        return err;
149    }
150    if !err.message.starts_with("compose: ") {
151        err.message = format!("compose: {}", err.message);
152    }
153    err
154}
155
156#[runtime_builtin(
157    name = "compose",
158    category = "strings/core",
159    summary = "Format values into string arrays using printf-style placeholders.",
160    keywords = "compose,format,string array,gpu",
161    accel = "sink",
162    type_resolver(string_array_type),
163    descriptor(crate::builtins::strings::core::compose::COMPOSE_DESCRIPTOR),
164    builtin_path = "crate::builtins::strings::core::compose"
165)]
166async fn compose_builtin(format_spec: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
167    let format_value = gather_if_needed_async(&format_spec)
168        .await
169        .map_err(remap_compose_flow)?;
170    let mut gathered_args = Vec::with_capacity(rest.len());
171    for arg in rest {
172        let gathered = gather_if_needed_async(&arg)
173            .await
174            .map_err(remap_compose_flow)?;
175        gathered_args.push(gathered);
176    }
177
178    if gathered_args.is_empty() {
179        let spec = extract_format_spec(format_value)
180            .await
181            .map_err(remap_compose_flow)?;
182        let array = format_spec_data_to_string_array(spec)?;
183        return Ok(Value::StringArray(array));
184    }
185
186    let formatted = format_from_spec(format_value, gathered_args)
187        .await
188        .map_err(remap_compose_flow)?;
189    Ok(Value::StringArray(formatted))
190}
191
192fn format_spec_data_to_string_array(spec: FormatSpecData) -> BuiltinResult<StringArray> {
193    let shape = if spec.shape.is_empty() {
194        match spec.specs.len() {
195            0 => vec![0, 0],
196            1 => vec![1, 1],
197            len => vec![len, 1],
198        }
199    } else {
200        spec.shape
201    };
202    StringArray::new(spec.specs, shape).map_err(|_| compose_error(&COMPOSE_ERROR_INTERNAL))
203}
204
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_builtins::{IntValue, ResolveContext, Tensor, Type};

    /// Drives the async builtin to completion on the current thread.
    fn compose_builtin(format_spec: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(super::compose_builtin(format_spec, rest))
    }

    /// Extracts the message text from a runtime error.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    /// Unwraps a `Value::StringArray`, failing the test for any other variant.
    fn expect_string_array(value: Value) -> StringArray {
        match value {
            Value::StringArray(sa) => sa,
            other => panic!("expected string array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_scalar_numeric() {
        let result = compose_builtin(Value::from("Count %d"), vec![Value::Int(IntValue::I32(7))])
            .expect("compose");
        let sa = expect_string_array(result);
        assert_eq!(sa.shape, vec![1, 1]);
        assert_eq!(sa.data, vec!["Count 7".to_string()]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_broadcasts_scalar_spec() {
        // A scalar spec is applied to each element of the data tensor.
        let tensor = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let result = compose_builtin(Value::from("Item %0.0f"), vec![Value::Tensor(tensor)])
            .expect("compose");
        let sa = expect_string_array(result);
        assert_eq!(sa.shape, vec![1, 2]);
        assert_eq!(sa.data, vec!["Item 1".to_string(), "Item 2".to_string()]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_zero_arguments_returns_spec() {
        // With no data arguments the spec is returned with its shape intact.
        let spec = Value::StringArray(
            StringArray::new(vec!["alpha".into(), "beta".into()], vec![1, 2]).unwrap(),
        );
        let sa = expect_string_array(compose_builtin(spec, Vec::new()).expect("compose"));
        assert_eq!(sa.shape, vec![1, 2]);
        assert_eq!(sa.data, vec!["alpha".to_string(), "beta".to_string()]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_mismatched_lengths_errors() {
        // A 1x2 spec with 1x3 data is neither scalar nor size-matched.
        let spec = Value::StringArray(
            StringArray::new(vec!["%d".into(), "%d".into()], vec![1, 2]).unwrap(),
        );
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
        let err = error_message(compose_builtin(spec, vec![Value::Tensor(tensor)]).unwrap_err());
        assert!(
            err.starts_with("compose: "),
            "expected compose prefix, got {err}"
        );
        assert!(
            err.contains("format data arguments must be scalars or match formatSpec size"),
            "unexpected error text: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_gpu_argument() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result =
                compose_builtin(Value::from("Value %0.0f"), vec![Value::GpuTensor(handle)])
                    .expect("compose");
            let sa = expect_string_array(result);
            assert_eq!(sa.shape, vec![1, 3]);
            assert_eq!(
                sa.data,
                vec![
                    "Value 1".to_string(),
                    "Value 2".to_string(),
                    "Value 3".to_string()
                ]
            );
        });
    }

    /// Only built with the `wgpu` feature; the provider registration result is
    /// ignored, and the GPU path must match the CPU path exactly.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn compose_wgpu_numeric_tensor_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![1.25, 2.5, 3.75], vec![1, 3]).unwrap();
        let cpu = compose_builtin(
            Value::from("Value %0.2f"),
            vec![Value::Tensor(tensor.clone())],
        )
        .expect("cpu compose");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let handle = provider.upload(&view).expect("gpu upload");
        let gpu = compose_builtin(Value::from("Value %0.2f"), vec![Value::GpuTensor(handle)])
            .expect("gpu compose");
        match (cpu, gpu) {
            (Value::StringArray(expect), Value::StringArray(actual)) => {
                assert_eq!(actual.shape, expect.shape);
                assert_eq!(actual.data, expect.data);
            }
            other => panic!("unexpected results {other:?}"),
        }
    }

    // Carries the same wasm32 test attribute as every other test in this module,
    // so the type-resolver check also runs under wasm-bindgen-test.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn compose_type_is_string_array() {
        assert_eq!(
            string_array_type(&[Type::String], &ResolveContext::new(Vec::new())),
            Type::cell_of(Type::String)
        );
    }
}