1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::broadcast::{broadcast_index, broadcast_shapes, compute_strides};
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::common::tensor;
16use crate::builtins::strings::search::text_utils::{logical_result, TextCollection, TextElement};
17use crate::builtins::strings::type_resolvers::logical_text_match_type;
18use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
19
// GPU planner metadata for `strcmpi`.
//
// The builtin does no device work: `strcmpi_builtin` gathers both operands with
// `gather_if_needed_async` before comparing on the host. Accordingly this spec lists
// no supported precisions and no provider hooks, and its residency policy is
// `GatherImmediately`.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strcmpi")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "strcmpi",
    // Custom op kind label; text comparison has no standard GPU op category.
    op_kind: GpuOpKind::Custom("string-compare"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Runs entirely on the CPU; GPU operands are gathered before comparison.",
};
35
// Fusion planner metadata for `strcmpi`.
//
// Both `elementwise` and `reduction` are `None`. As `notes` states, the builtin
// produces host-side logical results and is not eligible for GPU fusion.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strcmpi")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "strcmpi",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Produces logical host results; not eligible for GPU fusion.",
};
46
/// Builtin name passed to the shared text, broadcasting and result helpers, and
/// attached to every error this module builds.
const BUILTIN_NAME: &str = "strcmpi";
48
/// Output of the single `strcmpi` call form: the logical comparison result `tf`.
const STRCMPI_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "tf",
    ty: BuiltinParamType::LogicalArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Logical comparison result.",
}];

/// The two required operands `A` and `B`.
///
/// They are declared as `Any` because text validation happens at runtime, when
/// `strcmpi_builtin` converts each operand with `TextCollection::from_argument`.
const STRCMPI_INPUTS: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "First text input (string/char/cell/string array).",
    },
    BuiltinParamDescriptor {
        name: "B",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Second text input (string/char/cell/string array).",
    },
];

/// The only supported call form, `tf = strcmpi(A, B)`.
const STRCMPI_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
    label: "tf = strcmpi(A, B)",
    inputs: &STRCMPI_INPUTS,
    outputs: &STRCMPI_OUTPUT,
}];
79
/// Raised by `strcmpi_builtin` when either operand fails
/// `TextCollection::from_argument`. The underlying conversion error is discarded.
/// The `message` text is matched by the unit tests, so keep it stable.
const STRCMPI_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRCMPI.INVALID_INPUT",
    identifier: Some("RunMat:strcmpi:InvalidInput"),
    when: "At least one input is not a supported text container.",
    message: "strcmpi: text inputs must be string/char/cell/string-array values",
};

/// Raised by `evaluate_strcmpi` when `broadcast_shapes` rejects the operand shapes.
/// The `message` text is matched by the unit tests, so keep it stable.
const STRCMPI_ERROR_SHAPE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRCMPI.SHAPE_MISMATCH",
    identifier: Some("RunMat:strcmpi:ShapeMismatch"),
    when: "Inputs are not broadcast-compatible for elementwise comparison.",
    message: "strcmpi: input sizes are not broadcast-compatible",
};

/// Raised by `evaluate_strcmpi` when `logical_result` fails to assemble the output.
const STRCMPI_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.STRCMPI.INTERNAL",
    identifier: Some("RunMat:strcmpi:InternalError"),
    when: "Internal logical result assembly failed.",
    message: "strcmpi: internal error",
};

/// Every error descriptor this builtin can raise, published through
/// `STRCMPI_DESCRIPTOR`.
const STRCMPI_ERRORS: [BuiltinErrorDescriptor; 3] = [
    STRCMPI_ERROR_INVALID_INPUT,
    STRCMPI_ERROR_SHAPE_MISMATCH,
    STRCMPI_ERROR_INTERNAL,
];
106
/// Public builtin descriptor for `strcmpi`.
///
/// It bundles the call signature and error catalogue. It is referenced by path from
/// the `descriptor(...)` argument of the `runtime_builtin` attribute on
/// `strcmpi_builtin`.
pub const STRCMPI_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &STRCMPI_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &STRCMPI_ERRORS,
};
113
114fn strcmpi_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
115 strcmpi_error_with_message(error.message, error)
116}
117
118fn strcmpi_error_with_message(
119 message: impl Into<String>,
120 error: &'static BuiltinErrorDescriptor,
121) -> RuntimeError {
122 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
123 if let Some(identifier) = error.identifier {
124 builder = builder.with_identifier(identifier);
125 }
126 builder.build()
127}
128
129fn remap_strcmpi_flow(err: RuntimeError) -> RuntimeError {
130 map_control_flow_with_builtin(err, BUILTIN_NAME)
131}
132
133#[runtime_builtin(
134 name = "strcmpi",
135 category = "strings/core",
136 summary = "Compare text inputs for case-insensitive equality.",
137 keywords = "strcmpi,string compare,text equality",
138 accel = "sink",
139 type_resolver(logical_text_match_type),
140 descriptor(crate::builtins::strings::core::strcmpi::STRCMPI_DESCRIPTOR),
141 builtin_path = "crate::builtins::strings::core::strcmpi"
142)]
143async fn strcmpi_builtin(a: Value, b: Value) -> crate::BuiltinResult<Value> {
144 let a = gather_if_needed_async(&a)
145 .await
146 .map_err(remap_strcmpi_flow)?;
147 let b = gather_if_needed_async(&b)
148 .await
149 .map_err(remap_strcmpi_flow)?;
150 let left = TextCollection::from_argument(BUILTIN_NAME, a, "first argument")
151 .map_err(|_| strcmpi_error(&STRCMPI_ERROR_INVALID_INPUT))?;
152 let right = TextCollection::from_argument(BUILTIN_NAME, b, "second argument")
153 .map_err(|_| strcmpi_error(&STRCMPI_ERROR_INVALID_INPUT))?;
154 evaluate_strcmpi(&left, &right)
155}
156
157fn evaluate_strcmpi(left: &TextCollection, right: &TextCollection) -> BuiltinResult<Value> {
158 let shape = broadcast_shapes(BUILTIN_NAME, &left.shape, &right.shape)
159 .map_err(|_| strcmpi_error(&STRCMPI_ERROR_SHAPE_MISMATCH))?;
160 let total = tensor::element_count(&shape);
161 if total == 0 {
162 return logical_result(BUILTIN_NAME, Vec::new(), shape)
163 .map_err(|_| strcmpi_error(&STRCMPI_ERROR_INTERNAL));
164 }
165 let left_strides = compute_strides(&left.shape);
166 let right_strides = compute_strides(&right.shape);
167 let left_lower = left.lowercased();
168 let right_lower = right.lowercased();
169 let mut data = Vec::with_capacity(total);
170 for linear in 0..total {
171 let li = broadcast_index(linear, &shape, &left.shape, &left_strides);
172 let ri = broadcast_index(linear, &shape, &right.shape, &right_strides);
173 let equal = match (&left.elements[li], &right.elements[ri]) {
174 (TextElement::Missing, _) => false,
175 (_, TextElement::Missing) => false,
176 (TextElement::Text(_), TextElement::Text(_)) => {
177 match (&left_lower[li], &right_lower[ri]) {
178 (Some(lhs), Some(rhs)) => lhs == rhs,
179 _ => false,
180 }
181 }
182 };
183 data.push(if equal { 1 } else { 0 });
184 }
185 logical_result(BUILTIN_NAME, data, shape).map_err(|_| strcmpi_error(&STRCMPI_ERROR_INTERNAL))
186}
187
#[cfg(test)]
pub(crate) mod tests {
    // `use super::*` already brings `Value`, `BuiltinResult` and `RuntimeError` into
    // scope through the parent module's imports, so no separate `RuntimeError`
    // import is needed.
    use super::*;
    use runmat_builtins::{CellArray, CharArray, LogicalArray, ResolveContext, StringArray, Type};

    /// Synchronous wrapper around the async builtin so tests can call it directly.
    fn strcmpi_builtin(a: Value, b: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::strcmpi_builtin(a, b))
    }

    /// Renders a runtime error as its display string for substring assertions.
    fn error_message(err: RuntimeError) -> String {
        err.to_string()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_scalar_true_ignores_case() {
        let result = strcmpi_builtin(
            Value::String("RunMat".into()),
            Value::String("runmat".into()),
        )
        .expect("strcmpi");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_scalar_false_when_text_differs() {
        let result = strcmpi_builtin(
            Value::String("RunMat".into()),
            Value::String("runtime".into()),
        )
        .expect("strcmpi");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_array_broadcast_scalar_case_insensitive() {
        let array = StringArray::new(
            vec!["red".into(), "green".into(), "blue".into()],
            vec![1, 3],
        )
        .unwrap();
        let result = strcmpi_builtin(Value::StringArray(array), Value::String("GREEN".into()))
            .expect("strcmpi");
        let expected = LogicalArray::new(vec![0, 1, 0], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_char_array_row_compare_casefold() {
        let chars = CharArray::new(vec!['c', 'a', 't', 'D', 'O', 'G'], 2, 3).unwrap();
        let result =
            strcmpi_builtin(Value::CharArray(chars), Value::String("CaT".into())).expect("cmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_char_array_to_char_array_casefold() {
        let left = CharArray::new(vec!['A', 'b', 'C', 'd'], 2, 2).unwrap();
        let right = CharArray::new(vec!['a', 'B', 'x', 'Y'], 2, 2).unwrap();
        let result =
            strcmpi_builtin(Value::CharArray(left), Value::CharArray(right)).expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_cell_array_scalar_casefold() {
        let cell = CellArray::new(
            vec![
                Value::from("North"),
                Value::from("east"),
                Value::from("South"),
            ],
            1,
            3,
        )
        .unwrap();
        let result =
            strcmpi_builtin(Value::Cell(cell), Value::String("EAST".into())).expect("strcmpi");
        let expected = LogicalArray::new(vec![0, 1, 0], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_cell_array_vs_cell_array_broadcast() {
        let left = CellArray::new(vec![Value::from("North"), Value::from("East")], 1, 2).unwrap();
        let right = CellArray::new(vec![Value::from("north")], 1, 1).unwrap();
        let result = strcmpi_builtin(Value::Cell(left), Value::Cell(right)).expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_array_multi_dimensional_broadcast() {
        let left = StringArray::new(vec!["north".into(), "south".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(
            vec!["NORTH".into(), "EAST".into(), "SOUTH".into()],
            vec![1, 3],
        )
        .unwrap();
        let result =
            strcmpi_builtin(Value::StringArray(left), Value::StringArray(right)).expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0, 0, 0, 0, 1], vec![2, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_missing_strings_compare_false() {
        let strings = StringArray::new(vec!["<missing>".into()], vec![1, 1]).unwrap();
        let result = strcmpi_builtin(
            Value::StringArray(strings.clone()),
            Value::StringArray(strings),
        )
        .expect("strcmpi");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_char_array_trailing_space_not_equal() {
        let chars = CharArray::new(vec!['c', 'a', 't', ' '], 1, 4).unwrap();
        let result =
            strcmpi_builtin(Value::CharArray(chars), Value::String("cat".into())).expect("strcmpi");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_size_mismatch_error() {
        let left = StringArray::new(vec!["a".into(), "b".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![3, 1]).unwrap();
        let err = error_message(
            strcmpi_builtin(Value::StringArray(left), Value::StringArray(right))
                .expect_err("size mismatch"),
        );
        assert!(err.contains(STRCMPI_ERROR_SHAPE_MISMATCH.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_invalid_argument_type() {
        let err = error_message(
            strcmpi_builtin(Value::Num(1.0), Value::String("a".into())).expect_err("invalid type"),
        );
        assert!(err.contains(STRCMPI_ERROR_INVALID_INPUT.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_cell_array_invalid_element_errors() {
        let cell = CellArray::new(vec![Value::Num(42.0)], 1, 1).unwrap();
        let err = error_message(
            strcmpi_builtin(Value::Cell(cell), Value::String("test".into()))
                .expect_err("cell element type"),
        );
        assert!(err.contains(STRCMPI_ERROR_INVALID_INPUT.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_empty_char_array_returns_empty() {
        let chars = CharArray::new(Vec::<char>::new(), 0, 3).unwrap();
        let result = strcmpi_builtin(Value::CharArray(chars), Value::String("anything".into()))
            .expect("cmp");
        let expected = LogicalArray::new(Vec::<u8>::new(), vec![0, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn strcmpi_with_wgpu_provider_matches_expected() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let names = StringArray::new(vec!["North".into(), "south".into()], vec![2, 1]).unwrap();
        let comparison = StringArray::new(vec!["north".into()], vec![1, 1]).unwrap();
        let result = strcmpi_builtin(Value::StringArray(names), Value::StringArray(comparison))
            .expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    // Runs under wasm-bindgen-test as well, matching every other test in this module.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_type_is_logical_match() {
        assert_eq!(
            logical_text_match_type(
                &[Type::String, Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Bool
        );
    }
}