1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::broadcast::{broadcast_index, broadcast_shapes, compute_strides};
10use crate::builtins::common::map_control_flow_with_builtin;
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::common::tensor;
16use crate::builtins::strings::search::text_utils::{logical_result, TextCollection, TextElement};
17use crate::builtins::strings::type_resolvers::logical_text_match_type;
18use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
19
/// GPU planning metadata for the `strcmp` builtin.
///
/// No precisions or provider hooks are advertised, and residency is `GatherImmediately`.
/// Per `notes`, the comparison runs on the host. `strcmp_builtin` calls
/// `gather_if_needed_async` on both operands before comparing them.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strcmp")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "strcmp",
    op_kind: GpuOpKind::Custom("string-compare"),
    // Left empty: this op is evaluated on the host (see `notes`).
    supported_precisions: &[],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Device-resident inputs are brought back to the host rather than kept on the GPU.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Performs host-side text comparisons; GPU operands are gathered automatically before evaluation.",
};
35
/// Fusion metadata for the `strcmp` builtin.
///
/// Neither an elementwise nor a reduction template is supplied, so the planner has nothing
/// to fuse. Per `notes`, the logical result is produced on the host.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strcmp")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "strcmp",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // No fusion templates: the comparison is not expressed as a fusable kernel.
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Produces logical results on the host; not eligible for GPU fusion.",
};
46
47const BUILTIN_NAME: &str = "strcmp";
48
49const STRCMP_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
50 name: "tf",
51 ty: BuiltinParamType::LogicalArray,
52 arity: BuiltinParamArity::Required,
53 default: None,
54 description: "Logical comparison result.",
55}];
56
57const STRCMP_INPUTS: [BuiltinParamDescriptor; 2] = [
58 BuiltinParamDescriptor {
59 name: "A",
60 ty: BuiltinParamType::Any,
61 arity: BuiltinParamArity::Required,
62 default: None,
63 description: "First text input (string/char/cell/string array).",
64 },
65 BuiltinParamDescriptor {
66 name: "B",
67 ty: BuiltinParamType::Any,
68 arity: BuiltinParamArity::Required,
69 default: None,
70 description: "Second text input (string/char/cell/string array).",
71 },
72];
73
74const STRCMP_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
75 label: "tf = strcmp(A, B)",
76 inputs: &STRCMP_INPUTS,
77 outputs: &STRCMP_OUTPUT,
78}];
79
80const STRCMP_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
81 code: "RM.STRCMP.INVALID_INPUT",
82 identifier: Some("RunMat:strcmp:InvalidInput"),
83 when: "At least one input is not a supported text container.",
84 message: "strcmp: text inputs must be string/char/cell/string-array values",
85};
86
87const STRCMP_ERROR_SHAPE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
88 code: "RM.STRCMP.SHAPE_MISMATCH",
89 identifier: Some("RunMat:strcmp:ShapeMismatch"),
90 when: "Inputs are not broadcast-compatible for elementwise comparison.",
91 message: "strcmp: input sizes are not broadcast-compatible",
92};
93
94const STRCMP_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
95 code: "RM.STRCMP.INTERNAL",
96 identifier: Some("RunMat:strcmp:InternalError"),
97 when: "Internal logical result assembly failed.",
98 message: "strcmp: internal error",
99};
100
101const STRCMP_ERRORS: [BuiltinErrorDescriptor; 3] = [
102 STRCMP_ERROR_INVALID_INPUT,
103 STRCMP_ERROR_SHAPE_MISMATCH,
104 STRCMP_ERROR_INTERNAL,
105];
106
107pub const STRCMP_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
108 signatures: &STRCMP_SIGNATURES,
109 output_mode: BuiltinOutputMode::Fixed,
110 completion_policy: BuiltinCompletionPolicy::Public,
111 errors: &STRCMP_ERRORS,
112};
113
114fn strcmp_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
115 strcmp_error_with_message(error.message, error)
116}
117
118fn strcmp_error_with_message(
119 message: impl Into<String>,
120 error: &'static BuiltinErrorDescriptor,
121) -> RuntimeError {
122 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
123 if let Some(identifier) = error.identifier {
124 builder = builder.with_identifier(identifier);
125 }
126 builder.build()
127}
128
129fn remap_strcmp_flow(err: RuntimeError) -> RuntimeError {
130 map_control_flow_with_builtin(err, BUILTIN_NAME)
131}
132
133#[runtime_builtin(
134 name = "strcmp",
135 category = "strings/core",
136 summary = "Compare text inputs for exact case-sensitive equality.",
137 keywords = "strcmp,string compare,text equality",
138 accel = "sink",
139 type_resolver(logical_text_match_type),
140 descriptor(crate::builtins::strings::core::strcmp::STRCMP_DESCRIPTOR),
141 builtin_path = "crate::builtins::strings::core::strcmp"
142)]
143async fn strcmp_builtin(a: Value, b: Value) -> crate::BuiltinResult<Value> {
144 let a = gather_if_needed_async(&a)
145 .await
146 .map_err(remap_strcmp_flow)?;
147 let b = gather_if_needed_async(&b)
148 .await
149 .map_err(remap_strcmp_flow)?;
150 let left = TextCollection::from_argument(BUILTIN_NAME, a, "first argument")
151 .map_err(|_| strcmp_error(&STRCMP_ERROR_INVALID_INPUT))?;
152 let right = TextCollection::from_argument(BUILTIN_NAME, b, "second argument")
153 .map_err(|_| strcmp_error(&STRCMP_ERROR_INVALID_INPUT))?;
154 evaluate_strcmp(&left, &right)
155}
156
157fn evaluate_strcmp(left: &TextCollection, right: &TextCollection) -> BuiltinResult<Value> {
158 let shape = broadcast_shapes(BUILTIN_NAME, &left.shape, &right.shape)
159 .map_err(|_| strcmp_error(&STRCMP_ERROR_SHAPE_MISMATCH))?;
160 let total = tensor::element_count(&shape);
161 if total == 0 {
162 return logical_result(BUILTIN_NAME, Vec::new(), shape)
163 .map_err(|_| strcmp_error(&STRCMP_ERROR_INTERNAL));
164 }
165 let left_strides = compute_strides(&left.shape);
166 let right_strides = compute_strides(&right.shape);
167 let mut data = Vec::with_capacity(total);
168 for linear in 0..total {
169 let li = broadcast_index(linear, &shape, &left.shape, &left_strides);
170 let ri = broadcast_index(linear, &shape, &right.shape, &right_strides);
171 let equal = match (&left.elements[li], &right.elements[ri]) {
172 (TextElement::Missing, _) => false,
173 (_, TextElement::Missing) => false,
174 (TextElement::Text(lhs), TextElement::Text(rhs)) => lhs == rhs,
175 };
176 data.push(if equal { 1 } else { 0 });
177 }
178 logical_result(BUILTIN_NAME, data, shape).map_err(|_| strcmp_error(&STRCMP_ERROR_INTERNAL))
179}
180
/// Unit tests for the `strcmp` builtin.
///
/// Covers scalar, string-array, cell-array and char-matrix inputs; broadcasting; missing
/// strings; the documented error messages; and the type resolver.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::RuntimeError;
    use runmat_builtins::{CellArray, CharArray, LogicalArray, ResolveContext, StringArray, Type};

    /// Synchronous wrapper that drives the async builtin to completion.
    fn strcmp_builtin(a: Value, b: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::strcmp_builtin(a, b))
    }

    /// Renders a runtime error as its display text, for message assertions.
    fn error_message(err: RuntimeError) -> String {
        err.to_string()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_scalar_true() {
        let result = strcmp_builtin(
            Value::String("RunMat".into()),
            Value::String("RunMat".into()),
        )
        .expect("strcmp");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_scalar_false() {
        // The comparison is case-sensitive.
        let result = strcmp_builtin(
            Value::String("RunMat".into()),
            Value::String("runmat".into()),
        )
        .expect("strcmp");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_array_broadcast_scalar() {
        let array = StringArray::new(
            vec!["red".into(), "green".into(), "blue".into()],
            vec![1, 3],
        )
        .unwrap();
        let result =
            strcmp_builtin(Value::StringArray(array), Value::String("green".into())).expect("cmp");
        let expected = LogicalArray::new(vec![0, 1, 0], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_row_compare() {
        // Each row of the 2x3 char matrix is compared separately, giving a 2x1 result.
        let chars = CharArray::new(vec!['c', 'a', 't', 'd', 'o', 'g'], 2, 3).unwrap();
        let result =
            strcmp_builtin(Value::CharArray(chars), Value::String("cat".into())).expect("cmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_to_char_array() {
        let left = CharArray::new(vec!['a', 'b', 'c', 'd'], 2, 2).unwrap();
        let right = CharArray::new(vec!['a', 'b', 'x', 'y'], 2, 2).unwrap();
        let result =
            strcmp_builtin(Value::CharArray(left), Value::CharArray(right)).expect("strcmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_cell_array_scalar() {
        let cell = CellArray::new(
            vec![
                Value::from("apple"),
                Value::from("pear"),
                Value::from("grape"),
            ],
            1,
            3,
        )
        .unwrap();
        let result =
            strcmp_builtin(Value::Cell(cell), Value::String("grape".into())).expect("strcmp");
        let expected = LogicalArray::new(vec![0, 0, 1], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_cell_array_to_cell_array_broadcasts() {
        let left = CellArray::new(vec![Value::from("red"), Value::from("blue")], 2, 1).unwrap();
        let right = CellArray::new(vec![Value::from("red")], 1, 1).unwrap();
        let result = strcmp_builtin(Value::Cell(left), Value::Cell(right)).expect("strcmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_array_multi_dimensional_broadcast() {
        // A 2x1 column against a 1x3 row broadcasts to a 2x3 result.
        let left = StringArray::new(vec!["north".into(), "south".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(
            vec!["north".into(), "east".into(), "south".into()],
            vec![1, 3],
        )
        .unwrap();
        let result =
            strcmp_builtin(Value::StringArray(left), Value::StringArray(right)).expect("strcmp");
        let expected = LogicalArray::new(vec![1, 0, 0, 0, 0, 1], vec![2, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_trailing_space_is_not_equal() {
        let chars = CharArray::new(vec!['c', 'a', 't', ' '], 1, 4).unwrap();
        let result =
            strcmp_builtin(Value::CharArray(chars), Value::String("cat".into())).expect("strcmp");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_empty_rows_returns_empty() {
        let chars = CharArray::new(Vec::new(), 0, 0).unwrap();
        let result = strcmp_builtin(Value::CharArray(chars), Value::String("anything".into()))
            .expect("strcmp");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.shape, vec![0, 1]);
                assert!(array.data.is_empty());
            }
            other => panic!("expected empty logical array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_missing_strings_compare_false() {
        // The "<missing>" literal is presumably StringArray's encoding of a missing element.
        // The assertion below relies on it being treated as missing, not as literal text.
        let strings = StringArray::new(vec!["<missing>".into()], vec![1, 1]).unwrap();
        let result = strcmp_builtin(
            Value::StringArray(strings.clone()),
            Value::StringArray(strings),
        )
        .expect("strcmp");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_missing_string_false() {
        let array = StringArray::new(vec!["alpha".into(), "<missing>".into()], vec![1, 2]).unwrap();
        let result =
            strcmp_builtin(Value::StringArray(array), Value::String("alpha".into())).expect("cmp");
        let expected = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_size_mismatch_error() {
        let left = StringArray::new(vec!["a".into(), "b".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![3, 1]).unwrap();
        let err = error_message(
            strcmp_builtin(Value::StringArray(left), Value::StringArray(right))
                .expect_err("size mismatch"),
        );
        assert!(err.contains(STRCMP_ERROR_SHAPE_MISMATCH.message));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_invalid_argument_type() {
        let err = error_message(
            strcmp_builtin(Value::Num(1.0), Value::String("a".into())).expect_err("invalid type"),
        );
        assert!(err.contains(STRCMP_ERROR_INVALID_INPUT.message));
    }

    // Carries the same wasm test attribute as its siblings so it also runs under the
    // wasm-bindgen test harness; previously it only ran on native targets.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_type_is_logical_match() {
        assert_eq!(
            logical_text_match_type(
                &[Type::String, Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Bool
        );
    }
}