Skip to main content

runmat_runtime/builtins/strings/core/
strcmpi.rs

1//! MATLAB-compatible `strcmpi` builtin for RunMat (case-insensitive string comparison).
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::broadcast::{broadcast_index, broadcast_shapes, compute_strides};
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::common::tensor;
16use crate::builtins::strings::search::text_utils::{logical_result, TextCollection, TextElement};
17use crate::builtins::strings::type_resolvers::logical_text_match_type;
18use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
19
20#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strcmpi")]
21pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
22    name: "strcmpi",
23    op_kind: GpuOpKind::Custom("string-compare"),
24    supported_precisions: &[],
25    broadcast: BroadcastSemantics::Matlab,
26    provider_hooks: &[],
27    constant_strategy: ConstantStrategy::InlineLiteral,
28    residency: ResidencyPolicy::GatherImmediately,
29    nan_mode: ReductionNaN::Include,
30    two_pass_threshold: None,
31    workgroup_size: None,
32    accepts_nan_mode: false,
33    notes: "Runs entirely on the CPU; GPU operands are gathered before comparison.",
34};
35
36#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strcmpi")]
37pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
38    name: "strcmpi",
39    shape: ShapeRequirements::Any,
40    constant_strategy: ConstantStrategy::InlineLiteral,
41    elementwise: None,
42    reduction: None,
43    emits_nan: false,
44    notes: "Produces logical host results; not eligible for GPU fusion.",
45};
46
47const BUILTIN_NAME: &str = "strcmpi";
48
49const STRCMPI_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
50    name: "tf",
51    ty: BuiltinParamType::LogicalArray,
52    arity: BuiltinParamArity::Required,
53    default: None,
54    description: "Logical comparison result.",
55}];
56
57const STRCMPI_INPUTS: [BuiltinParamDescriptor; 2] = [
58    BuiltinParamDescriptor {
59        name: "A",
60        ty: BuiltinParamType::Any,
61        arity: BuiltinParamArity::Required,
62        default: None,
63        description: "First text input (string/char/cell/string array).",
64    },
65    BuiltinParamDescriptor {
66        name: "B",
67        ty: BuiltinParamType::Any,
68        arity: BuiltinParamArity::Required,
69        default: None,
70        description: "Second text input (string/char/cell/string array).",
71    },
72];
73
74const STRCMPI_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
75    label: "tf = strcmpi(A, B)",
76    inputs: &STRCMPI_INPUTS,
77    outputs: &STRCMPI_OUTPUT,
78}];
79
80const STRCMPI_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
81    code: "RM.STRCMPI.INVALID_INPUT",
82    identifier: Some("RunMat:strcmpi:InvalidInput"),
83    when: "At least one input is not a supported text container.",
84    message: "strcmpi: text inputs must be string/char/cell/string-array values",
85};
86
87const STRCMPI_ERROR_SHAPE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
88    code: "RM.STRCMPI.SHAPE_MISMATCH",
89    identifier: Some("RunMat:strcmpi:ShapeMismatch"),
90    when: "Inputs are not broadcast-compatible for elementwise comparison.",
91    message: "strcmpi: input sizes are not broadcast-compatible",
92};
93
94const STRCMPI_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
95    code: "RM.STRCMPI.INTERNAL",
96    identifier: Some("RunMat:strcmpi:InternalError"),
97    when: "Internal logical result assembly failed.",
98    message: "strcmpi: internal error",
99};
100
101const STRCMPI_ERRORS: [BuiltinErrorDescriptor; 3] = [
102    STRCMPI_ERROR_INVALID_INPUT,
103    STRCMPI_ERROR_SHAPE_MISMATCH,
104    STRCMPI_ERROR_INTERNAL,
105];
106
107pub const STRCMPI_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
108    signatures: &STRCMPI_SIGNATURES,
109    output_mode: BuiltinOutputMode::Fixed,
110    completion_policy: BuiltinCompletionPolicy::Public,
111    errors: &STRCMPI_ERRORS,
112};
113
114fn strcmpi_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
115    strcmpi_error_with_message(error.message, error)
116}
117
118fn strcmpi_error_with_message(
119    message: impl Into<String>,
120    error: &'static BuiltinErrorDescriptor,
121) -> RuntimeError {
122    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
123    if let Some(identifier) = error.identifier {
124        builder = builder.with_identifier(identifier);
125    }
126    builder.build()
127}
128
129fn remap_strcmpi_flow(err: RuntimeError) -> RuntimeError {
130    map_control_flow_with_builtin(err, BUILTIN_NAME)
131}
132
133#[runtime_builtin(
134    name = "strcmpi",
135    category = "strings/core",
136    summary = "Compare text inputs for case-insensitive equality.",
137    keywords = "strcmpi,string compare,text equality",
138    accel = "sink",
139    type_resolver(logical_text_match_type),
140    descriptor(crate::builtins::strings::core::strcmpi::STRCMPI_DESCRIPTOR),
141    builtin_path = "crate::builtins::strings::core::strcmpi"
142)]
143async fn strcmpi_builtin(a: Value, b: Value) -> crate::BuiltinResult<Value> {
144    let a = gather_if_needed_async(&a)
145        .await
146        .map_err(remap_strcmpi_flow)?;
147    let b = gather_if_needed_async(&b)
148        .await
149        .map_err(remap_strcmpi_flow)?;
150    let left = TextCollection::from_argument(BUILTIN_NAME, a, "first argument")
151        .map_err(|_| strcmpi_error(&STRCMPI_ERROR_INVALID_INPUT))?;
152    let right = TextCollection::from_argument(BUILTIN_NAME, b, "second argument")
153        .map_err(|_| strcmpi_error(&STRCMPI_ERROR_INVALID_INPUT))?;
154    evaluate_strcmpi(&left, &right)
155}
156
157fn evaluate_strcmpi(left: &TextCollection, right: &TextCollection) -> BuiltinResult<Value> {
158    let shape = broadcast_shapes(BUILTIN_NAME, &left.shape, &right.shape)
159        .map_err(|_| strcmpi_error(&STRCMPI_ERROR_SHAPE_MISMATCH))?;
160    let total = tensor::element_count(&shape);
161    if total == 0 {
162        return logical_result(BUILTIN_NAME, Vec::new(), shape)
163            .map_err(|_| strcmpi_error(&STRCMPI_ERROR_INTERNAL));
164    }
165    let left_strides = compute_strides(&left.shape);
166    let right_strides = compute_strides(&right.shape);
167    let left_lower = left.lowercased();
168    let right_lower = right.lowercased();
169    let mut data = Vec::with_capacity(total);
170    for linear in 0..total {
171        let li = broadcast_index(linear, &shape, &left.shape, &left_strides);
172        let ri = broadcast_index(linear, &shape, &right.shape, &right_strides);
173        let equal = match (&left.elements[li], &right.elements[ri]) {
174            (TextElement::Missing, _) => false,
175            (_, TextElement::Missing) => false,
176            (TextElement::Text(_), TextElement::Text(_)) => {
177                match (&left_lower[li], &right_lower[ri]) {
178                    (Some(lhs), Some(rhs)) => lhs == rhs,
179                    _ => false,
180                }
181            }
182        };
183        data.push(if equal { 1 } else { 0 });
184    }
185    logical_result(BUILTIN_NAME, data, shape).map_err(|_| strcmpi_error(&STRCMPI_ERROR_INTERNAL))
186}
187
#[cfg(test)]
pub(crate) mod tests {
    //! Behavioural tests for `strcmpi`. Every test is also registered with
    //! `wasm_bindgen_test` so the suite runs identically on wasm32.

    use super::*;
    use runmat_builtins::{CellArray, CharArray, LogicalArray, ResolveContext, StringArray, Type};

    /// Drive the async builtin to completion on the current thread.
    fn strcmpi_builtin(a: Value, b: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::strcmpi_builtin(a, b))
    }

    /// Render an error the way users see it, for substring assertions.
    fn error_message(err: RuntimeError) -> String {
        err.to_string()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_scalar_true_ignores_case() {
        let result = strcmpi_builtin(
            Value::String("RunMat".into()),
            Value::String("runmat".into()),
        )
        .expect("strcmpi");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_scalar_false_when_text_differs() {
        let result = strcmpi_builtin(
            Value::String("RunMat".into()),
            Value::String("runtime".into()),
        )
        .expect("strcmpi");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_array_broadcast_scalar_case_insensitive() {
        let array = StringArray::new(
            vec!["red".into(), "green".into(), "blue".into()],
            vec![1, 3],
        )
        .unwrap();
        let result = strcmpi_builtin(Value::StringArray(array), Value::String("GREEN".into()))
            .expect("strcmpi");
        let expected = LogicalArray::new(vec![0, 1, 0], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_char_array_row_compare_casefold() {
        let chars = CharArray::new(vec!['c', 'a', 't', 'D', 'O', 'G'], 2, 3).unwrap();
        let result =
            strcmpi_builtin(Value::CharArray(chars), Value::String("CaT".into())).expect("cmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_char_array_to_char_array_casefold() {
        let left = CharArray::new(vec!['A', 'b', 'C', 'd'], 2, 2).unwrap();
        let right = CharArray::new(vec!['a', 'B', 'x', 'Y'], 2, 2).unwrap();
        let result =
            strcmpi_builtin(Value::CharArray(left), Value::CharArray(right)).expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_cell_array_scalar_casefold() {
        let cell = CellArray::new(
            vec![
                Value::from("North"),
                Value::from("east"),
                Value::from("South"),
            ],
            1,
            3,
        )
        .unwrap();
        let result =
            strcmpi_builtin(Value::Cell(cell), Value::String("EAST".into())).expect("strcmpi");
        let expected = LogicalArray::new(vec![0, 1, 0], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_cell_array_vs_cell_array_broadcast() {
        let left = CellArray::new(vec![Value::from("North"), Value::from("East")], 1, 2).unwrap();
        let right = CellArray::new(vec![Value::from("north")], 1, 1).unwrap();
        let result = strcmpi_builtin(Value::Cell(left), Value::Cell(right)).expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    /// A single char row acts as one text element and broadcasts across a cell array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_cell_array_vs_char_row_casefold() {
        let cell = CellArray::new(vec![Value::from("abc"), Value::from("abd")], 1, 2).unwrap();
        let row = CharArray::new(vec!['A', 'B', 'C'], 1, 3).unwrap();
        let result = strcmpi_builtin(Value::Cell(cell), Value::CharArray(row)).expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_array_multi_dimensional_broadcast() {
        let left = StringArray::new(vec!["north".into(), "south".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(
            vec!["NORTH".into(), "EAST".into(), "SOUTH".into()],
            vec![1, 3],
        )
        .unwrap();
        let result =
            strcmpi_builtin(Value::StringArray(left), Value::StringArray(right)).expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0, 0, 0, 0, 1], vec![2, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_missing_strings_compare_false() {
        let strings = StringArray::new(vec!["<missing>".into()], vec![1, 1]).unwrap();
        let result = strcmpi_builtin(
            Value::StringArray(strings.clone()),
            Value::StringArray(strings),
        )
        .expect("strcmpi");
        assert_eq!(result, Value::Bool(false));
    }

    /// A missing element only forces `false` at its own position; neighbours still match.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_missing_element_only_affects_its_position() {
        let strings =
            StringArray::new(vec!["<missing>".into(), "abc".into()], vec![1, 2]).unwrap();
        let result = strcmpi_builtin(Value::StringArray(strings), Value::String("ABC".into()))
            .expect("strcmpi");
        let expected = LogicalArray::new(vec![0, 1], vec![1, 2]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_char_array_trailing_space_not_equal() {
        let chars = CharArray::new(vec!['c', 'a', 't', ' '], 1, 4).unwrap();
        let result =
            strcmpi_builtin(Value::CharArray(chars), Value::String("cat".into())).expect("strcmpi");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_size_mismatch_error() {
        let left = StringArray::new(vec!["a".into(), "b".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![3, 1]).unwrap();
        let err = error_message(
            strcmpi_builtin(Value::StringArray(left), Value::StringArray(right))
                .expect_err("size mismatch"),
        );
        assert!(err.contains(STRCMPI_ERROR_SHAPE_MISMATCH.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_invalid_argument_type() {
        let err = error_message(
            strcmpi_builtin(Value::Num(1.0), Value::String("a".into())).expect_err("invalid type"),
        );
        assert!(err.contains(STRCMPI_ERROR_INVALID_INPUT.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_cell_array_invalid_element_errors() {
        let cell = CellArray::new(vec![Value::Num(42.0)], 1, 1).unwrap();
        let err = error_message(
            strcmpi_builtin(Value::Cell(cell), Value::String("test".into()))
                .expect_err("cell element type"),
        );
        assert!(err.contains(STRCMPI_ERROR_INVALID_INPUT.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_empty_char_array_returns_empty() {
        let chars = CharArray::new(Vec::<char>::new(), 0, 3).unwrap();
        let result = strcmpi_builtin(Value::CharArray(chars), Value::String("anything".into()))
            .expect("cmp");
        let expected = LogicalArray::new(Vec::<u8>::new(), vec![0, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn strcmpi_with_wgpu_provider_matches_expected() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let names = StringArray::new(vec!["North".into(), "south".into()], vec![2, 1]).unwrap();
        let comparison = StringArray::new(vec!["north".into()], vec![1, 1]).unwrap();
        let result = strcmpi_builtin(Value::StringArray(names), Value::StringArray(comparison))
            .expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    // Previously the only test without the wasm attribute, so it was skipped on wasm32.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_type_is_logical_match() {
        assert_eq!(
            logical_text_match_type(
                &[Type::String, Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Bool
        );
    }
}
387            ),
388            Type::Bool
389        );
390    }
391}