1use runmat_builtins::Value;
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::broadcast::{broadcast_index, broadcast_shapes, compute_strides};
7use crate::builtins::common::map_control_flow_with_builtin;
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::builtins::common::tensor;
13use crate::builtins::strings::search::text_utils::{logical_result, TextCollection, TextElement};
14use crate::builtins::strings::type_resolvers::logical_text_match_type;
15use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
16
17#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strcmp")]
18pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
19 name: "strcmp",
20 op_kind: GpuOpKind::Custom("string-compare"),
21 supported_precisions: &[],
22 broadcast: BroadcastSemantics::Matlab,
23 provider_hooks: &[],
24 constant_strategy: ConstantStrategy::InlineLiteral,
25 residency: ResidencyPolicy::GatherImmediately,
26 nan_mode: ReductionNaN::Include,
27 two_pass_threshold: None,
28 workgroup_size: None,
29 accepts_nan_mode: false,
30 notes: "Performs host-side text comparisons; GPU operands are gathered automatically before evaluation.",
31};
32
33#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strcmp")]
34pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
35 name: "strcmp",
36 shape: ShapeRequirements::Any,
37 constant_strategy: ConstantStrategy::InlineLiteral,
38 elementwise: None,
39 reduction: None,
40 emits_nan: false,
41 notes: "Produces logical results on the host; not eligible for GPU fusion.",
42};
43
44#[allow(dead_code)]
45fn strcmp_flow(message: impl Into<String>) -> RuntimeError {
46 build_runtime_error(message).with_builtin("strcmp").build()
47}
48
49fn remap_strcmp_flow(err: RuntimeError) -> RuntimeError {
50 map_control_flow_with_builtin(err, "strcmp")
51}
52
53#[runtime_builtin(
54 name = "strcmp",
55 category = "strings/core",
56 summary = "Compare text inputs for exact matches (case-sensitive).",
57 keywords = "strcmp,string compare,text equality",
58 accel = "sink",
59 type_resolver(logical_text_match_type),
60 builtin_path = "crate::builtins::strings::core::strcmp"
61)]
62async fn strcmp_builtin(a: Value, b: Value) -> crate::BuiltinResult<Value> {
63 let a = gather_if_needed_async(&a)
64 .await
65 .map_err(remap_strcmp_flow)?;
66 let b = gather_if_needed_async(&b)
67 .await
68 .map_err(remap_strcmp_flow)?;
69 let left = TextCollection::from_argument("strcmp", a, "first argument")?;
70 let right = TextCollection::from_argument("strcmp", b, "second argument")?;
71 evaluate_strcmp(&left, &right)
72}
73
74fn evaluate_strcmp(left: &TextCollection, right: &TextCollection) -> BuiltinResult<Value> {
75 let shape = broadcast_shapes("strcmp", &left.shape, &right.shape)?;
76 let total = tensor::element_count(&shape);
77 if total == 0 {
78 return logical_result("strcmp", Vec::new(), shape);
79 }
80 let left_strides = compute_strides(&left.shape);
81 let right_strides = compute_strides(&right.shape);
82 let mut data = Vec::with_capacity(total);
83 for linear in 0..total {
84 let li = broadcast_index(linear, &shape, &left.shape, &left_strides);
85 let ri = broadcast_index(linear, &shape, &right.shape, &right_strides);
86 let equal = match (&left.elements[li], &right.elements[ri]) {
87 (TextElement::Missing, _) => false,
88 (_, TextElement::Missing) => false,
89 (TextElement::Text(lhs), TextElement::Text(rhs)) => lhs == rhs,
90 };
91 data.push(if equal { 1 } else { 0 });
92 }
93 logical_result("strcmp", data, shape)
94}
95
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::RuntimeError;
    use runmat_builtins::{CellArray, CharArray, LogicalArray, ResolveContext, StringArray, Type};

    /// Runs the async builtin to completion with `futures::executor::block_on`.
    /// This lets the tests below call it synchronously.
    fn strcmp_builtin(a: Value, b: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::strcmp_builtin(a, b))
    }

    /// Extracts the human-readable message from a `RuntimeError` for substring checks.
    fn error_message(err: RuntimeError) -> String {
        err.message().to_string()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_scalar_true() {
        let result = strcmp_builtin(
            Value::String("RunMat".into()),
            Value::String("RunMat".into()),
        )
        .expect("strcmp");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_scalar_false() {
        let result = strcmp_builtin(
            Value::String("RunMat".into()),
            Value::String("runmat".into()),
        )
        .expect("strcmp");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_array_broadcast_scalar() {
        let array = StringArray::new(
            vec!["red".into(), "green".into(), "blue".into()],
            vec![1, 3],
        )
        .unwrap();
        let result =
            strcmp_builtin(Value::StringArray(array), Value::String("green".into())).expect("cmp");
        let expected = LogicalArray::new(vec![0, 1, 0], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_row_compare() {
        let chars = CharArray::new(vec!['c', 'a', 't', 'd', 'o', 'g'], 2, 3).unwrap();
        let result =
            strcmp_builtin(Value::CharArray(chars), Value::String("cat".into())).expect("cmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_to_char_array() {
        let left = CharArray::new(vec!['a', 'b', 'c', 'd'], 2, 2).unwrap();
        let right = CharArray::new(vec!['a', 'b', 'x', 'y'], 2, 2).unwrap();
        let result =
            strcmp_builtin(Value::CharArray(left), Value::CharArray(right)).expect("strcmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_cell_array_scalar() {
        let cell = CellArray::new(
            vec![
                Value::from("apple"),
                Value::from("pear"),
                Value::from("grape"),
            ],
            1,
            3,
        )
        .unwrap();
        let result =
            strcmp_builtin(Value::Cell(cell), Value::String("grape".into())).expect("strcmp");
        let expected = LogicalArray::new(vec![0, 0, 1], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_cell_array_to_cell_array_broadcasts() {
        let left = CellArray::new(vec![Value::from("red"), Value::from("blue")], 2, 1).unwrap();
        let right = CellArray::new(vec![Value::from("red")], 1, 1).unwrap();
        let result = strcmp_builtin(Value::Cell(left), Value::Cell(right)).expect("strcmp");
        let expected = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_string_array_multi_dimensional_broadcast() {
        let left = StringArray::new(vec!["north".into(), "south".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(
            vec!["north".into(), "east".into(), "south".into()],
            vec![1, 3],
        )
        .unwrap();
        let result =
            strcmp_builtin(Value::StringArray(left), Value::StringArray(right)).expect("strcmp");
        let expected = LogicalArray::new(vec![1, 0, 0, 0, 0, 1], vec![2, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_trailing_space_is_not_equal() {
        let chars = CharArray::new(vec!['c', 'a', 't', ' '], 1, 4).unwrap();
        let result =
            strcmp_builtin(Value::CharArray(chars), Value::String("cat".into())).expect("strcmp");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_char_array_empty_rows_returns_empty() {
        let chars = CharArray::new(Vec::new(), 0, 0).unwrap();
        let result = strcmp_builtin(Value::CharArray(chars), Value::String("anything".into()))
            .expect("strcmp");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.shape, vec![0, 1]);
                assert!(array.data.is_empty());
            }
            other => panic!("expected empty logical array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_missing_strings_compare_false() {
        let strings = StringArray::new(vec!["<missing>".into()], vec![1, 1]).unwrap();
        let result = strcmp_builtin(
            Value::StringArray(strings.clone()),
            Value::StringArray(strings),
        )
        .expect("strcmp");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_missing_string_false() {
        let array = StringArray::new(vec!["alpha".into(), "<missing>".into()], vec![1, 2]).unwrap();
        let result =
            strcmp_builtin(Value::StringArray(array), Value::String("alpha".into())).expect("cmp");
        let expected = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    // Mirror of `strcmp_missing_string_false` with the missing value on the right-hand side,
    // so both `Missing` positions in the comparison are exercised.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_missing_string_on_right_false() {
        let array = StringArray::new(vec!["alpha".into(), "<missing>".into()], vec![1, 2]).unwrap();
        let result =
            strcmp_builtin(Value::String("alpha".into()), Value::StringArray(array)).expect("cmp");
        let expected = LogicalArray::new(vec![1, 0], vec![1, 2]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_size_mismatch_error() {
        let left = StringArray::new(vec!["a".into(), "b".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![3, 1]).unwrap();
        let err = error_message(
            strcmp_builtin(Value::StringArray(left), Value::StringArray(right))
                .expect_err("size mismatch"),
        );
        assert!(err.contains("size mismatch"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_invalid_argument_type() {
        let err = error_message(
            strcmp_builtin(Value::Num(1.0), Value::String("a".into())).expect_err("invalid type"),
        );
        assert!(err.contains("first argument must be text"));
    }

    // Previously the only test without the wasm32 `cfg_attr`, so it never ran on that target.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strcmp_type_is_logical_match() {
        assert_eq!(
            logical_text_match_type(
                &[Type::String, Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Bool
        );
    }
}