1use runmat_builtins::Value;
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::broadcast::{broadcast_index, broadcast_shapes, compute_strides};
7use crate::builtins::common::map_control_flow_with_builtin;
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::builtins::common::tensor;
13use crate::builtins::strings::search::text_utils::{logical_result, TextCollection, TextElement};
14use crate::builtins::strings::type_resolvers::logical_text_match_type;
15use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
16
/// Name of this builtin; passed to the shared broadcast/text/logical helpers and
/// interpolated into every error message raised in this file.
const FN_NAME: &str = "strncmp";
18
19#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strncmp")]
20pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
21 name: "strncmp",
22 op_kind: GpuOpKind::Custom("string-prefix-compare"),
23 supported_precisions: &[],
24 broadcast: BroadcastSemantics::Matlab,
25 provider_hooks: &[],
26 constant_strategy: ConstantStrategy::InlineLiteral,
27 residency: ResidencyPolicy::GatherImmediately,
28 nan_mode: ReductionNaN::Include,
29 two_pass_threshold: None,
30 workgroup_size: None,
31 accepts_nan_mode: false,
32 notes: "Performs host-side prefix comparisons; GPU inputs are gathered before evaluation.",
33};
34
35#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strncmp")]
36pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
37 name: "strncmp",
38 shape: ShapeRequirements::Any,
39 constant_strategy: ConstantStrategy::InlineLiteral,
40 elementwise: None,
41 reduction: None,
42 emits_nan: false,
43 notes: "Produces logical host results and is not eligible for GPU fusion.",
44};
45
46fn strncmp_flow(message: impl Into<String>) -> RuntimeError {
47 build_runtime_error(message).with_builtin(FN_NAME).build()
48}
49
50fn remap_strncmp_flow(err: RuntimeError) -> RuntimeError {
51 map_control_flow_with_builtin(err, FN_NAME)
52}
53
54#[runtime_builtin(
55 name = "strncmp",
56 category = "strings/core",
57 summary = "Compare text inputs for equality up to N leading characters (case-sensitive).",
58 keywords = "strncmp,string compare,prefix,text equality",
59 accel = "sink",
60 type_resolver(logical_text_match_type),
61 builtin_path = "crate::builtins::strings::core::strncmp"
62)]
63async fn strncmp_builtin(a: Value, b: Value, n: Value) -> crate::BuiltinResult<Value> {
64 let a = gather_if_needed_async(&a)
65 .await
66 .map_err(remap_strncmp_flow)?;
67 let b = gather_if_needed_async(&b)
68 .await
69 .map_err(remap_strncmp_flow)?;
70 let n = gather_if_needed_async(&n)
71 .await
72 .map_err(remap_strncmp_flow)?;
73
74 let limit = parse_prefix_length(n)?;
75 let left = TextCollection::from_argument(FN_NAME, a, "first argument")?;
76 let right = TextCollection::from_argument(FN_NAME, b, "second argument")?;
77 evaluate_strncmp(&left, &right, limit)
78}
79
80fn evaluate_strncmp(
81 left: &TextCollection,
82 right: &TextCollection,
83 limit: usize,
84) -> BuiltinResult<Value> {
85 let shape = broadcast_shapes(FN_NAME, &left.shape, &right.shape)?;
86 let total = tensor::element_count(&shape);
87 if total == 0 {
88 return logical_result(FN_NAME, Vec::new(), shape);
89 }
90
91 let left_strides = compute_strides(&left.shape);
92 let right_strides = compute_strides(&right.shape);
93 let mut data = Vec::with_capacity(total);
94
95 for linear in 0..total {
96 let li = broadcast_index(linear, &shape, &left.shape, &left_strides);
97 let ri = broadcast_index(linear, &shape, &right.shape, &right_strides);
98 let equal = if limit == 0 {
99 true
100 } else {
101 match (&left.elements[li], &right.elements[ri]) {
102 (TextElement::Missing, _) | (_, TextElement::Missing) => false,
103 (TextElement::Text(lhs), TextElement::Text(rhs)) => prefix_equal(lhs, rhs, limit),
104 }
105 };
106 data.push(if equal { 1 } else { 0 });
107 }
108
109 logical_result(FN_NAME, data, shape)
110}
111
/// Returns `true` when the first `limit` characters of `lhs` and `rhs` agree.
///
/// Characters are Rust `char`s (Unicode scalar values). If one string runs out
/// before `limit` characters, the other must run out at the same point for the
/// prefixes to match; a `limit` of zero always matches.
fn prefix_equal(lhs: &str, rhs: &str, limit: usize) -> bool {
    // Each side yields min(len, limit) chars; `eq` requires equal lengths and
    // equal characters, which is exactly the prefix rule described above.
    lhs.chars().take(limit).eq(rhs.chars().take(limit))
}
141
142fn parse_prefix_length(value: Value) -> BuiltinResult<usize> {
143 match value {
144 Value::Int(i) => {
145 let raw = i.to_i64();
146 if raw < 0 {
147 return Err(strncmp_flow(format!(
148 "{FN_NAME}: prefix length must be a nonnegative integer"
149 )));
150 }
151 Ok(raw as usize)
152 }
153 Value::Num(n) => parse_prefix_length_from_float(n),
154 Value::Bool(b) => Ok(if b { 1 } else { 0 }),
155 Value::Tensor(tensor) => {
156 if tensor.data.len() != 1 {
157 return Err(strncmp_flow(format!(
158 "{FN_NAME}: prefix length must be a nonnegative integer scalar"
159 )));
160 }
161 parse_prefix_length_from_float(tensor.data[0])
162 }
163 Value::LogicalArray(array) => {
164 if array.data.len() != 1 {
165 return Err(strncmp_flow(format!(
166 "{FN_NAME}: prefix length must be a nonnegative integer scalar"
167 )));
168 }
169 Ok(if array.data[0] != 0 { 1 } else { 0 })
170 }
171 other => Err(strncmp_flow(format!(
172 "{FN_NAME}: prefix length must be a nonnegative integer scalar, received {other:?}"
173 ))),
174 }
175}
176
177fn parse_prefix_length_from_float(value: f64) -> BuiltinResult<usize> {
178 if !value.is_finite() {
179 return Err(strncmp_flow(format!(
180 "{FN_NAME}: prefix length must be a finite nonnegative integer"
181 )));
182 }
183 if value < 0.0 {
184 return Err(strncmp_flow(format!(
185 "{FN_NAME}: prefix length must be a nonnegative integer"
186 )));
187 }
188 let rounded = value.round();
189 if (rounded - value).abs() > f64::EPSILON {
190 return Err(strncmp_flow(format!(
191 "{FN_NAME}: prefix length must be a nonnegative integer"
192 )));
193 }
194 if rounded > (usize::MAX as f64) {
195 return Err(strncmp_flow(format!(
196 "{FN_NAME}: prefix length exceeds the maximum supported size"
197 )));
198 }
199 Ok(rounded as usize)
200}
201
/// Unit tests for `strncmp`.
///
/// Covers prefix matching, each accepted prefix-length representation, char/cell/string
/// containers with broadcasting, missing strings, error paths, the type resolver, and
/// (behind the `wgpu` feature) a GPU-resident prefix length.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_builtins::{
        CellArray, CharArray, IntValue, LogicalArray, ResolveContext, StringArray, Tensor, Type,
    };

    // Synchronous shim: drives the async builtin to completion with
    // `futures::executor::block_on` so tests can call it like a plain function.
    fn strncmp_builtin(a: Value, b: Value, n: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::strncmp_builtin(a, b, n))
    }

    // Returns the error's message text for substring assertions.
    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    // "Run" is shared, so a 3-character prefix matches.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_exact_prefix_true() {
        let result = strncmp_builtin(
            Value::String("RunMat".into()),
            Value::String("Runway".into()),
            Value::Int(IntValue::I32(3)),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    // The 4th characters differ ('M' vs 'w').
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_mismatch_within_prefix_false() {
        let result = strncmp_builtin(
            Value::String("RunMat".into()),
            Value::String("Runway".into()),
            Value::Int(IntValue::I32(4)),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(false));
    }

    // "cat" has only 3 characters, so it cannot match a 4-character prefix of "cater".
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_longer_string_after_prefix_false() {
        let result = strncmp_builtin(
            Value::String("cat".into()),
            Value::String("cater".into()),
            Value::Int(IntValue::I32(4)),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(false));
    }

    // n = 0 matches regardless of content.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_zero_length_always_true() {
        let result = strncmp_builtin(
            Value::String("alpha".into()),
            Value::String("omega".into()),
            Value::Num(0.0),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    // Logical true acts as n = 1; both strings start with 'a'.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_prefix_length_bool_true_compares_first_character() {
        let result = strncmp_builtin(
            Value::String("alpha".into()),
            Value::String("array".into()),
            Value::Bool(true),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    // Logical false acts as n = 0, which always matches.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_prefix_length_bool_false_treated_as_zero() {
        let result = strncmp_builtin(
            Value::String("alpha".into()),
            Value::String("omega".into()),
            Value::Bool(false),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    // A one-element logical array [true] acts as n = 1; 'b' != 't'.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_prefix_length_logical_array_scalar() {
        let logical = LogicalArray::new(vec![1], vec![1]).unwrap();
        let result = strncmp_builtin(
            Value::String("beta".into()),
            Value::String("theta".into()),
            Value::LogicalArray(logical),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(false));
    }

    // A 1x1 double tensor supplies n = 2; "ga" == "ga".
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_prefix_length_tensor_scalar_double() {
        let limit = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
        let result = strncmp_builtin(
            Value::String("gamma".into()),
            Value::String("gamut".into()),
            Value::Tensor(limit),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    // Each row of a 3x5 char matrix is compared separately (rows presumably
    // "cat  ", "camel", "cow  "), giving a 3x1 logical column.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_char_array_rows() {
        let chars = CharArray::new(
            vec![
                'c', 'a', 't', ' ', ' ', 'c', 'a', 'm', 'e', 'l', 'c', 'o', 'w', ' ', ' ',
            ],
            3,
            5,
        )
        .unwrap();
        let result = strncmp_builtin(
            Value::CharArray(chars),
            Value::String("ca".into()),
            Value::Int(IntValue::I32(2)),
        )
        .expect("strncmp");
        let expected = LogicalArray::new(vec![1, 1, 0], vec![3, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    // Same-shaped 1x3 cell arrays compare element-wise:
    // "re"/"ro" differ, "gr"/"gr" and "bl"/"bl" match.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_cell_arrays_broadcast() {
        let left = CellArray::new(
            vec![
                Value::from("red"),
                Value::from("green"),
                Value::from("blue"),
            ],
            1,
            3,
        )
        .unwrap();
        let right = CellArray::new(
            vec![
                Value::from("rose"),
                Value::from("gray"),
                Value::from("black"),
            ],
            1,
            3,
        )
        .unwrap();
        let result = strncmp_builtin(
            Value::Cell(left),
            Value::Cell(right),
            Value::Int(IntValue::I32(2)),
        )
        .expect("strncmp");
        let expected = LogicalArray::new(vec![0, 1, 1], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    // A scalar string broadcasts across a 1x3 string array.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_string_array_broadcast_scalar() {
        let strings = StringArray::new(
            vec!["north".into(), "south".into(), "east".into()],
            vec![1, 3],
        )
        .unwrap();
        let result = strncmp_builtin(
            Value::StringArray(strings),
            Value::String("no".into()),
            Value::Int(IntValue::I32(2)),
        )
        .expect("strncmp");
        let expected = LogicalArray::new(vec![1, 0, 0], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    // The "<missing>" entry is expected to be treated as a missing string, which
    // never matches when n > 0.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_missing_string_false_when_prefix_positive() {
        let strings =
            StringArray::new(vec!["<missing>".into(), "value".into()], vec![1, 2]).unwrap();
        let result = strncmp_builtin(
            Value::StringArray(strings),
            Value::String("val".into()),
            Value::Int(IntValue::I32(3)),
        )
        .expect("strncmp");
        let expected = LogicalArray::new(vec![0, 1], vec![1, 2]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    // With n = 0 even a missing string matches.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_missing_zero_length_true() {
        let strings = StringArray::new(vec!["<missing>".into()], vec![1, 1]).unwrap();
        let result = strncmp_builtin(
            Value::StringArray(strings),
            Value::String("anything".into()),
            Value::Int(IntValue::I32(0)),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    // 2x1 and 3x1 inputs cannot be broadcast together.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_size_mismatch_error() {
        let left = StringArray::new(vec!["a".into(), "b".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![3, 1]).unwrap();
        let err = error_message(
            strncmp_builtin(
                Value::StringArray(left),
                Value::StringArray(right),
                Value::Int(IntValue::I32(1)),
            )
            .expect_err("size mismatch"),
        );
        assert!(err.contains("size mismatch"));
    }

    // A string is not an accepted prefix-length type.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_invalid_length_type_errors() {
        let err = error_message(
            strncmp_builtin(
                Value::String("abc".into()),
                Value::String("abc".into()),
                Value::String("3".into()),
            )
            .expect_err("invalid prefix length"),
        );
        assert!(err.contains("prefix length"));
    }

    // Negative prefix lengths are rejected.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_negative_length_errors() {
        let err = error_message(
            strncmp_builtin(
                Value::String("abc".into()),
                Value::String("abc".into()),
                Value::Num(-1.0),
            )
            .expect_err("negative length"),
        );
        assert!(err.to_ascii_lowercase().contains("nonnegative"));
    }

    // The prefix length may live on the GPU; it is gathered before parsing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn strncmp_prefix_length_from_gpu_tensor() {
        use runmat_accelerate::backend::wgpu::provider::{
            register_wgpu_provider, WgpuProviderOptions,
        };
        use runmat_accelerate_api::HostTensorView;

        // Return early (effectively skip) when registering the wgpu provider fails.
        let provider = match register_wgpu_provider(WgpuProviderOptions::default()) {
            Ok(provider) => provider,
            Err(_) => return,
        };
        let tensor = Tensor::new(vec![3.0], vec![1, 1]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload prefix length to GPU");
        // n = 3: "del" == "del".
        let result = strncmp_builtin(
            Value::String("delta".into()),
            Value::String("deluge".into()),
            Value::GpuTensor(handle.clone()),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
        let _ = provider.free(&handle);
    }

    // The type resolver reports a logical result for two string inputs.
    // NOTE(review): unlike the other tests this one has no wasm_bindgen_test
    // attribute — confirm whether that is intentional.
    #[test]
    fn strncmp_type_is_logical_match() {
        assert_eq!(
            logical_text_match_type(
                &[Type::String, Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Bool
        );
    }
}