1use runmat_builtins::Value;
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::broadcast::{broadcast_index, broadcast_shapes, compute_strides};
7use crate::builtins::common::map_control_flow_with_builtin;
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::builtins::common::tensor;
13use crate::builtins::strings::search::text_utils::{logical_result, TextCollection, TextElement};
14use crate::builtins::strings::type_resolvers::logical_text_match_type;
15use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
16
17#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strcmpi")]
18pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
19 name: "strcmpi",
20 op_kind: GpuOpKind::Custom("string-compare"),
21 supported_precisions: &[],
22 broadcast: BroadcastSemantics::Matlab,
23 provider_hooks: &[],
24 constant_strategy: ConstantStrategy::InlineLiteral,
25 residency: ResidencyPolicy::GatherImmediately,
26 nan_mode: ReductionNaN::Include,
27 two_pass_threshold: None,
28 workgroup_size: None,
29 accepts_nan_mode: false,
30 notes: "Runs entirely on the CPU; GPU operands are gathered before comparison.",
31};
32
33#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strcmpi")]
34pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
35 name: "strcmpi",
36 shape: ShapeRequirements::Any,
37 constant_strategy: ConstantStrategy::InlineLiteral,
38 elementwise: None,
39 reduction: None,
40 emits_nan: false,
41 notes: "Produces logical host results; not eligible for GPU fusion.",
42};
43
44#[allow(dead_code)]
45fn strcmpi_flow(message: impl Into<String>) -> RuntimeError {
46 build_runtime_error(message).with_builtin("strcmpi").build()
47}
48
49fn remap_strcmpi_flow(err: RuntimeError) -> RuntimeError {
50 map_control_flow_with_builtin(err, "strcmpi")
51}
52
53#[runtime_builtin(
54 name = "strcmpi",
55 category = "strings/core",
56 summary = "Compare text inputs for equality without considering case.",
57 keywords = "strcmpi,string compare,text equality",
58 accel = "sink",
59 type_resolver(logical_text_match_type),
60 builtin_path = "crate::builtins::strings::core::strcmpi"
61)]
62async fn strcmpi_builtin(a: Value, b: Value) -> crate::BuiltinResult<Value> {
63 let a = gather_if_needed_async(&a)
64 .await
65 .map_err(remap_strcmpi_flow)?;
66 let b = gather_if_needed_async(&b)
67 .await
68 .map_err(remap_strcmpi_flow)?;
69 let left = TextCollection::from_argument("strcmpi", a, "first argument")?;
70 let right = TextCollection::from_argument("strcmpi", b, "second argument")?;
71 evaluate_strcmpi(&left, &right)
72}
73
74fn evaluate_strcmpi(left: &TextCollection, right: &TextCollection) -> BuiltinResult<Value> {
75 let shape = broadcast_shapes("strcmpi", &left.shape, &right.shape)?;
76 let total = tensor::element_count(&shape);
77 if total == 0 {
78 return logical_result("strcmpi", Vec::new(), shape);
79 }
80 let left_strides = compute_strides(&left.shape);
81 let right_strides = compute_strides(&right.shape);
82 let left_lower = left.lowercased();
83 let right_lower = right.lowercased();
84 let mut data = Vec::with_capacity(total);
85 for linear in 0..total {
86 let li = broadcast_index(linear, &shape, &left.shape, &left_strides);
87 let ri = broadcast_index(linear, &shape, &right.shape, &right_strides);
88 let equal = match (&left.elements[li], &right.elements[ri]) {
89 (TextElement::Missing, _) => false,
90 (_, TextElement::Missing) => false,
91 (TextElement::Text(_), TextElement::Text(_)) => {
92 match (&left_lower[li], &right_lower[ri]) {
93 (Some(lhs), Some(rhs)) => lhs == rhs,
94 _ => false,
95 }
96 }
97 };
98 data.push(if equal { 1 } else { 0 });
99 }
100 logical_result("strcmpi", data, shape)
101}
102
#[cfg(test)]
pub(crate) mod tests {
    // Unit tests for `strcmpi`: case-insensitive matching across string, char, and cell
    // inputs, broadcasting, missing-value handling, error messages, and the type resolver.
    use super::*;
    use crate::RuntimeError;
    use runmat_builtins::{CellArray, CharArray, LogicalArray, ResolveContext, StringArray, Type};

    /// Synchronous driver for the async builtin so tests can call it directly.
    fn strcmpi_builtin(a: Value, b: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::strcmpi_builtin(a, b))
    }

    fn error_message(err: RuntimeError) -> String {
        err.message().to_string()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_scalar_true_ignores_case() {
        let result = strcmpi_builtin(
            Value::String("RunMat".into()),
            Value::String("runmat".into()),
        )
        .expect("strcmpi");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_scalar_false_when_text_differs() {
        let result = strcmpi_builtin(
            Value::String("RunMat".into()),
            Value::String("runtime".into()),
        )
        .expect("strcmpi");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_array_broadcast_scalar_case_insensitive() {
        let array = StringArray::new(
            vec!["red".into(), "green".into(), "blue".into()],
            vec![1, 3],
        )
        .unwrap();
        let result = strcmpi_builtin(Value::StringArray(array), Value::String("GREEN".into()))
            .expect("strcmpi");
        let expected = LogicalArray::new(vec![0, 1, 0], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_char_array_row_compare_casefold() {
        let chars = CharArray::new(vec!['c', 'a', 't', 'D', 'O', 'G'], 2, 3).unwrap();
        let result =
            strcmpi_builtin(Value::CharArray(chars), Value::String("CaT".into())).expect("cmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_char_row_vs_string_scalar_ignores_case() {
        // A single char row is one text element, so the result collapses to a scalar.
        let chars = CharArray::new(vec!['C', 'A', 'T'], 1, 3).unwrap();
        let result = strcmpi_builtin(Value::CharArray(chars), Value::String("cat".into()))
            .expect("strcmpi");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_char_array_to_char_array_casefold() {
        let left = CharArray::new(vec!['A', 'b', 'C', 'd'], 2, 2).unwrap();
        let right = CharArray::new(vec!['a', 'B', 'x', 'Y'], 2, 2).unwrap();
        let result =
            strcmpi_builtin(Value::CharArray(left), Value::CharArray(right)).expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_cell_array_scalar_casefold() {
        let cell = CellArray::new(
            vec![
                Value::from("North"),
                Value::from("east"),
                Value::from("South"),
            ],
            1,
            3,
        )
        .unwrap();
        let result =
            strcmpi_builtin(Value::Cell(cell), Value::String("EAST".into())).expect("strcmpi");
        let expected = LogicalArray::new(vec![0, 1, 0], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_cell_array_vs_cell_array_broadcast() {
        let left = CellArray::new(vec![Value::from("North"), Value::from("East")], 1, 2).unwrap();
        let right = CellArray::new(vec![Value::from("north")], 1, 1).unwrap();
        let result = strcmpi_builtin(Value::Cell(left), Value::Cell(right)).expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_string_array_multi_dimensional_broadcast() {
        let left = StringArray::new(vec!["north".into(), "south".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(
            vec!["NORTH".into(), "EAST".into(), "SOUTH".into()],
            vec![1, 3],
        )
        .unwrap();
        let result =
            strcmpi_builtin(Value::StringArray(left), Value::StringArray(right)).expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0, 0, 0, 0, 1], vec![2, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_missing_strings_compare_false() {
        let strings = StringArray::new(vec!["<missing>".into()], vec![1, 1]).unwrap();
        let result = strcmpi_builtin(
            Value::StringArray(strings.clone()),
            Value::StringArray(strings),
        )
        .expect("strcmpi");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_missing_element_never_matches_text() {
        // The missing element must be false while its case-folded neighbour still matches.
        let strings =
            StringArray::new(vec!["<missing>".into(), "a".into()], vec![1, 2]).unwrap();
        let result = strcmpi_builtin(Value::StringArray(strings), Value::String("A".into()))
            .expect("strcmpi");
        let expected = LogicalArray::new(vec![0, 1], vec![1, 2]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_char_array_trailing_space_not_equal() {
        let chars = CharArray::new(vec!['c', 'a', 't', ' '], 1, 4).unwrap();
        let result =
            strcmpi_builtin(Value::CharArray(chars), Value::String("cat".into())).expect("strcmpi");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_size_mismatch_error() {
        let left = StringArray::new(vec!["a".into(), "b".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![3, 1]).unwrap();
        let err = error_message(
            strcmpi_builtin(Value::StringArray(left), Value::StringArray(right))
                .expect_err("size mismatch"),
        );
        assert!(err.contains("size mismatch"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_invalid_argument_type() {
        let err = error_message(
            strcmpi_builtin(Value::Num(1.0), Value::String("a".into())).expect_err("invalid type"),
        );
        assert!(err.contains("first argument must be text"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_cell_array_invalid_element_errors() {
        let cell = CellArray::new(vec![Value::Num(42.0)], 1, 1).unwrap();
        let err = error_message(
            strcmpi_builtin(Value::Cell(cell), Value::String("test".into()))
                .expect_err("cell element type"),
        );
        assert!(err.contains("cell array elements must be character vectors or string scalars"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_empty_char_array_returns_empty() {
        let chars = CharArray::new(Vec::<char>::new(), 0, 3).unwrap();
        let result = strcmpi_builtin(Value::CharArray(chars), Value::String("anything".into()))
            .expect("cmp");
        let expected = LogicalArray::new(Vec::<u8>::new(), vec![0, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn strcmpi_with_wgpu_provider_matches_expected() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let names = StringArray::new(vec!["North".into(), "south".into()], vec![2, 1]).unwrap();
        let comparison = StringArray::new(vec!["north".into()], vec![1, 1]).unwrap();
        let result = strcmpi_builtin(Value::StringArray(names), Value::StringArray(comparison))
            .expect("strcmpi");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmpi_type_is_logical_match() {
        assert_eq!(
            logical_text_match_type(
                &[Type::String, Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Bool
        );
    }
}