Skip to main content

runmat_runtime/builtins/strings/core/
strcmp.rs

//! MATLAB-compatible `strcmp` builtin for RunMat.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::broadcast::{broadcast_index, broadcast_shapes, compute_strides};
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::common::tensor;
16use crate::builtins::strings::search::text_utils::{logical_result, TextCollection, TextElement};
17use crate::builtins::strings::type_resolvers::logical_text_match_type;
18use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
19
20#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strcmp")]
21pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
22    name: "strcmp",
23    op_kind: GpuOpKind::Custom("string-compare"),
24    supported_precisions: &[],
25    broadcast: BroadcastSemantics::Matlab,
26    provider_hooks: &[],
27    constant_strategy: ConstantStrategy::InlineLiteral,
28    residency: ResidencyPolicy::GatherImmediately,
29    nan_mode: ReductionNaN::Include,
30    two_pass_threshold: None,
31    workgroup_size: None,
32    accepts_nan_mode: false,
33    notes: "Performs host-side text comparisons; GPU operands are gathered automatically before evaluation.",
34};
35
36#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strcmp")]
37pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
38    name: "strcmp",
39    shape: ShapeRequirements::Any,
40    constant_strategy: ConstantStrategy::InlineLiteral,
41    elementwise: None,
42    reduction: None,
43    emits_nan: false,
44    notes: "Produces logical results on the host; not eligible for GPU fusion.",
45};
46
/// Name under which this builtin is reported in errors and passed to shared helpers.
const BUILTIN_NAME: &str = "strcmp";
48
49const STRCMP_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
50    name: "tf",
51    ty: BuiltinParamType::LogicalArray,
52    arity: BuiltinParamArity::Required,
53    default: None,
54    description: "Logical comparison result.",
55}];
56
57const STRCMP_INPUTS: [BuiltinParamDescriptor; 2] = [
58    BuiltinParamDescriptor {
59        name: "A",
60        ty: BuiltinParamType::Any,
61        arity: BuiltinParamArity::Required,
62        default: None,
63        description: "First text input (string/char/cell/string array).",
64    },
65    BuiltinParamDescriptor {
66        name: "B",
67        ty: BuiltinParamType::Any,
68        arity: BuiltinParamArity::Required,
69        default: None,
70        description: "Second text input (string/char/cell/string array).",
71    },
72];
73
74const STRCMP_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
75    label: "tf = strcmp(A, B)",
76    inputs: &STRCMP_INPUTS,
77    outputs: &STRCMP_OUTPUT,
78}];
79
80const STRCMP_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
81    code: "RM.STRCMP.INVALID_INPUT",
82    identifier: Some("RunMat:strcmp:InvalidInput"),
83    when: "At least one input is not a supported text container.",
84    message: "strcmp: text inputs must be string/char/cell/string-array values",
85};
86
87const STRCMP_ERROR_SHAPE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
88    code: "RM.STRCMP.SHAPE_MISMATCH",
89    identifier: Some("RunMat:strcmp:ShapeMismatch"),
90    when: "Inputs are not broadcast-compatible for elementwise comparison.",
91    message: "strcmp: input sizes are not broadcast-compatible",
92};
93
94const STRCMP_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
95    code: "RM.STRCMP.INTERNAL",
96    identifier: Some("RunMat:strcmp:InternalError"),
97    when: "Internal logical result assembly failed.",
98    message: "strcmp: internal error",
99};
100
101const STRCMP_ERRORS: [BuiltinErrorDescriptor; 3] = [
102    STRCMP_ERROR_INVALID_INPUT,
103    STRCMP_ERROR_SHAPE_MISMATCH,
104    STRCMP_ERROR_INTERNAL,
105];
106
107pub const STRCMP_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
108    signatures: &STRCMP_SIGNATURES,
109    output_mode: BuiltinOutputMode::Fixed,
110    completion_policy: BuiltinCompletionPolicy::Public,
111    errors: &STRCMP_ERRORS,
112};
113
114fn strcmp_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
115    strcmp_error_with_message(error.message, error)
116}
117
118fn strcmp_error_with_message(
119    message: impl Into<String>,
120    error: &'static BuiltinErrorDescriptor,
121) -> RuntimeError {
122    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
123    if let Some(identifier) = error.identifier {
124        builder = builder.with_identifier(identifier);
125    }
126    builder.build()
127}
128
129fn remap_strcmp_flow(err: RuntimeError) -> RuntimeError {
130    map_control_flow_with_builtin(err, BUILTIN_NAME)
131}
132
133#[runtime_builtin(
134    name = "strcmp",
135    category = "strings/core",
136    summary = "Compare text inputs for exact case-sensitive equality.",
137    keywords = "strcmp,string compare,text equality",
138    accel = "sink",
139    type_resolver(logical_text_match_type),
140    descriptor(crate::builtins::strings::core::strcmp::STRCMP_DESCRIPTOR),
141    builtin_path = "crate::builtins::strings::core::strcmp"
142)]
143async fn strcmp_builtin(a: Value, b: Value) -> crate::BuiltinResult<Value> {
144    let a = gather_if_needed_async(&a)
145        .await
146        .map_err(remap_strcmp_flow)?;
147    let b = gather_if_needed_async(&b)
148        .await
149        .map_err(remap_strcmp_flow)?;
150    let left = TextCollection::from_argument(BUILTIN_NAME, a, "first argument")
151        .map_err(|_| strcmp_error(&STRCMP_ERROR_INVALID_INPUT))?;
152    let right = TextCollection::from_argument(BUILTIN_NAME, b, "second argument")
153        .map_err(|_| strcmp_error(&STRCMP_ERROR_INVALID_INPUT))?;
154    evaluate_strcmp(&left, &right)
155}
156
157fn evaluate_strcmp(left: &TextCollection, right: &TextCollection) -> BuiltinResult<Value> {
158    let shape = broadcast_shapes(BUILTIN_NAME, &left.shape, &right.shape)
159        .map_err(|_| strcmp_error(&STRCMP_ERROR_SHAPE_MISMATCH))?;
160    let total = tensor::element_count(&shape);
161    if total == 0 {
162        return logical_result(BUILTIN_NAME, Vec::new(), shape)
163            .map_err(|_| strcmp_error(&STRCMP_ERROR_INTERNAL));
164    }
165    let left_strides = compute_strides(&left.shape);
166    let right_strides = compute_strides(&right.shape);
167    let mut data = Vec::with_capacity(total);
168    for linear in 0..total {
169        let li = broadcast_index(linear, &shape, &left.shape, &left_strides);
170        let ri = broadcast_index(linear, &shape, &right.shape, &right_strides);
171        let equal = match (&left.elements[li], &right.elements[ri]) {
172            (TextElement::Missing, _) => false,
173            (_, TextElement::Missing) => false,
174            (TextElement::Text(lhs), TextElement::Text(rhs)) => lhs == rhs,
175        };
176        data.push(if equal { 1 } else { 0 });
177    }
178    logical_result(BUILTIN_NAME, data, shape).map_err(|_| strcmp_error(&STRCMP_ERROR_INTERNAL))
179}
180
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::RuntimeError;
    use runmat_builtins::{CellArray, CharArray, LogicalArray, ResolveContext, StringArray, Type};

    /// Drives the async builtin to completion so the tests can call it synchronously.
    fn strcmp_builtin(a: Value, b: Value) -> BuiltinResult<Value> {
        let pending = super::strcmp_builtin(a, b);
        futures::executor::block_on(pending)
    }

    /// Renders an error as text so assertions can search its message.
    fn error_message(err: RuntimeError) -> String {
        err.to_string()
    }

    // Scalar string operands.

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_scalar_true() {
        let lhs = Value::String("RunMat".into());
        let rhs = Value::String("RunMat".into());
        let outcome = strcmp_builtin(lhs, rhs).expect("strcmp");
        assert_eq!(outcome, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_scalar_false() {
        // Differs only in letter case.
        let lhs = Value::String("RunMat".into());
        let rhs = Value::String("runmat".into());
        let outcome = strcmp_builtin(lhs, rhs).expect("strcmp");
        assert_eq!(outcome, Value::Bool(false));
    }

    // String arrays, char arrays and cell arrays.

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_array_broadcast_scalar() {
        let colors = StringArray::new(
            vec!["red".into(), "green".into(), "blue".into()],
            vec![1, 3],
        )
        .unwrap();
        let needle = Value::String("green".into());
        let outcome = strcmp_builtin(Value::StringArray(colors), needle).expect("cmp");
        let expected = LogicalArray::new(vec![0, 1, 0], vec![1, 3]).unwrap();
        assert_eq!(outcome, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_row_compare() {
        let rows = CharArray::new(vec!['c', 'a', 't', 'd', 'o', 'g'], 2, 3).unwrap();
        let needle = Value::String("cat".into());
        let outcome = strcmp_builtin(Value::CharArray(rows), needle).expect("cmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(outcome, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_to_char_array() {
        let lhs = Value::CharArray(CharArray::new(vec!['a', 'b', 'c', 'd'], 2, 2).unwrap());
        let rhs = Value::CharArray(CharArray::new(vec!['a', 'b', 'x', 'y'], 2, 2).unwrap());
        let outcome = strcmp_builtin(lhs, rhs).expect("strcmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(outcome, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_cell_array_scalar() {
        let fruits = vec![
            Value::from("apple"),
            Value::from("pear"),
            Value::from("grape"),
        ];
        let cell = CellArray::new(fruits, 1, 3).unwrap();
        let needle = Value::String("grape".into());
        let outcome = strcmp_builtin(Value::Cell(cell), needle).expect("strcmp");
        let expected = LogicalArray::new(vec![0, 0, 1], vec![1, 3]).unwrap();
        assert_eq!(outcome, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_cell_array_to_cell_array_broadcasts() {
        let column = CellArray::new(vec![Value::from("red"), Value::from("blue")], 2, 1).unwrap();
        let single = CellArray::new(vec![Value::from("red")], 1, 1).unwrap();
        let outcome = strcmp_builtin(Value::Cell(column), Value::Cell(single)).expect("strcmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(outcome, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_array_multi_dimensional_broadcast() {
        let column = StringArray::new(vec!["north".into(), "south".into()], vec![2, 1]).unwrap();
        let row = StringArray::new(
            vec!["north".into(), "east".into(), "south".into()],
            vec![1, 3],
        )
        .unwrap();
        let outcome =
            strcmp_builtin(Value::StringArray(column), Value::StringArray(row)).expect("strcmp");
        let expected = LogicalArray::new(vec![1, 0, 0, 0, 0, 1], vec![2, 3]).unwrap();
        assert_eq!(outcome, Value::LogicalArray(expected));
    }

    // Edge cases: trailing whitespace, empty inputs, missing strings.

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_trailing_space_is_not_equal() {
        let padded = CharArray::new(vec!['c', 'a', 't', ' '], 1, 4).unwrap();
        let needle = Value::String("cat".into());
        let outcome = strcmp_builtin(Value::CharArray(padded), needle).expect("strcmp");
        assert_eq!(outcome, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_empty_rows_returns_empty() {
        let empty = CharArray::new(Vec::new(), 0, 0).unwrap();
        let needle = Value::String("anything".into());
        let outcome = strcmp_builtin(Value::CharArray(empty), needle).expect("strcmp");
        match outcome {
            Value::LogicalArray(array) => {
                assert_eq!(array.shape, vec![0, 1]);
                assert!(array.data.is_empty());
            }
            other => panic!("expected empty logical array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_missing_strings_compare_false() {
        let missing = StringArray::new(vec!["<missing>".into()], vec![1, 1]).unwrap();
        let lhs = Value::StringArray(missing.clone());
        let rhs = Value::StringArray(missing);
        let outcome = strcmp_builtin(lhs, rhs).expect("strcmp");
        assert_eq!(outcome, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_missing_string_false() {
        let mixed = StringArray::new(vec!["alpha".into(), "<missing>".into()], vec![1, 2]).unwrap();
        let needle = Value::String("alpha".into());
        let outcome = strcmp_builtin(Value::StringArray(mixed), needle).expect("cmp");
        let expected = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        assert_eq!(outcome, Value::LogicalArray(expected));
    }

    // Error paths.

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_size_mismatch_error() {
        let two = StringArray::new(vec!["a".into(), "b".into()], vec![2, 1]).unwrap();
        let three = StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![3, 1]).unwrap();
        let failure = strcmp_builtin(Value::StringArray(two), Value::StringArray(three))
            .expect_err("size mismatch");
        let rendered = error_message(failure);
        assert!(rendered.contains(STRCMP_ERROR_SHAPE_MISMATCH.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_invalid_argument_type() {
        let failure = strcmp_builtin(Value::Num(1.0), Value::String("a".into()))
            .expect_err("invalid type");
        let rendered = error_message(failure);
        assert!(rendered.contains(STRCMP_ERROR_INVALID_INPUT.message));
    }

    // Static type resolution.

    #[test]
    fn strcmp_type_is_logical_match() {
        let context = ResolveContext::new(Vec::new());
        let resolved = logical_text_match_type(&[Type::String, Type::String], &context);
        assert_eq!(resolved, Type::Bool);
    }
}