Skip to main content

runmat_runtime/builtins/strings/core/
strncmp.rs

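//! MATLAB-compatible `strncmp` builtin for RunMat.
//!
//! `strncmp(a, b, n)` compares text inputs case-sensitively over at most their first `n`
//! characters, broadcasting `a` against `b` and returning a logical result.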
2
3use runmat_builtins::Value;
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::broadcast::{broadcast_index, broadcast_shapes, compute_strides};
7use crate::builtins::common::map_control_flow_with_builtin;
8use crate::builtins::common::spec::{
9    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10    ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::builtins::common::tensor;
13use crate::builtins::strings::search::text_utils::{logical_result, TextCollection, TextElement};
14use crate::builtins::strings::type_resolvers::logical_text_match_type;
15use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
16
/// Builtin name passed to the shared text/broadcast helpers and used to tag and prefix every
/// error message raised by this module.
const FN_NAME: &str = "strncmp";
18
/// Acceleration metadata for `strncmp`.
///
/// The comparison itself runs on the host: any GPU-resident input is gathered before evaluation
/// (`ResidencyPolicy::GatherImmediately`), so no device precisions or provider hooks are listed.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::core::strncmp")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "strncmp",
    // Custom op label; there is no standard elementwise/reduction kind for prefix comparison.
    op_kind: GpuOpKind::Custom("string-prefix-compare"),
    // Empty: no device precision is advertised for this host-only builtin.
    supported_precisions: &[],
    // `a` and `b` broadcast against each other with MATLAB rules (see `broadcast_shapes`).
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Performs host-side prefix comparisons; GPU inputs are gathered before evaluation.",
};
34
/// Fusion metadata for `strncmp`.
///
/// No elementwise or reduction template is provided, so the builtin is never fused; it always
/// produces a host-side logical result.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::core::strncmp")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "strncmp",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    // `None` for both templates is what keeps this builtin out of fused kernels.
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Produces logical host results and is not eligible for GPU fusion.",
};
45
46fn strncmp_flow(message: impl Into<String>) -> RuntimeError {
47    build_runtime_error(message).with_builtin(FN_NAME).build()
48}
49
50fn remap_strncmp_flow(err: RuntimeError) -> RuntimeError {
51    map_control_flow_with_builtin(err, FN_NAME)
52}
53
54#[runtime_builtin(
55    name = "strncmp",
56    category = "strings/core",
57    summary = "Compare text inputs for equality up to N leading characters (case-sensitive).",
58    keywords = "strncmp,string compare,prefix,text equality",
59    accel = "sink",
60    type_resolver(logical_text_match_type),
61    builtin_path = "crate::builtins::strings::core::strncmp"
62)]
63async fn strncmp_builtin(a: Value, b: Value, n: Value) -> crate::BuiltinResult<Value> {
64    let a = gather_if_needed_async(&a)
65        .await
66        .map_err(remap_strncmp_flow)?;
67    let b = gather_if_needed_async(&b)
68        .await
69        .map_err(remap_strncmp_flow)?;
70    let n = gather_if_needed_async(&n)
71        .await
72        .map_err(remap_strncmp_flow)?;
73
74    let limit = parse_prefix_length(n)?;
75    let left = TextCollection::from_argument(FN_NAME, a, "first argument")?;
76    let right = TextCollection::from_argument(FN_NAME, b, "second argument")?;
77    evaluate_strncmp(&left, &right, limit)
78}
79
80fn evaluate_strncmp(
81    left: &TextCollection,
82    right: &TextCollection,
83    limit: usize,
84) -> BuiltinResult<Value> {
85    let shape = broadcast_shapes(FN_NAME, &left.shape, &right.shape)?;
86    let total = tensor::element_count(&shape);
87    if total == 0 {
88        return logical_result(FN_NAME, Vec::new(), shape);
89    }
90
91    let left_strides = compute_strides(&left.shape);
92    let right_strides = compute_strides(&right.shape);
93    let mut data = Vec::with_capacity(total);
94
95    for linear in 0..total {
96        let li = broadcast_index(linear, &shape, &left.shape, &left_strides);
97        let ri = broadcast_index(linear, &shape, &right.shape, &right_strides);
98        let equal = if limit == 0 {
99            true
100        } else {
101            match (&left.elements[li], &right.elements[ri]) {
102                (TextElement::Missing, _) | (_, TextElement::Missing) => false,
103                (TextElement::Text(lhs), TextElement::Text(rhs)) => prefix_equal(lhs, rhs, limit),
104            }
105        };
106        data.push(if equal { 1 } else { 0 });
107    }
108
109    logical_result(FN_NAME, data, shape)
110}
111
/// Returns `true` when `lhs` and `rhs` agree on their first `limit` characters.
///
/// Characters are Unicode scalar values, compared case-sensitively. If either string ends before
/// `limit` characters, the two match only when they end at the same position. Consequently a
/// `limit` past both lengths compares the whole strings, and a `limit` of 0 always matches.
fn prefix_equal(lhs: &str, rhs: &str, limit: usize) -> bool {
    lhs.chars().take(limit).eq(rhs.chars().take(limit))
}
141
142fn parse_prefix_length(value: Value) -> BuiltinResult<usize> {
143    match value {
144        Value::Int(i) => {
145            let raw = i.to_i64();
146            if raw < 0 {
147                return Err(strncmp_flow(format!(
148                    "{FN_NAME}: prefix length must be a nonnegative integer"
149                )));
150            }
151            Ok(raw as usize)
152        }
153        Value::Num(n) => parse_prefix_length_from_float(n),
154        Value::Bool(b) => Ok(if b { 1 } else { 0 }),
155        Value::Tensor(tensor) => {
156            if tensor.data.len() != 1 {
157                return Err(strncmp_flow(format!(
158                    "{FN_NAME}: prefix length must be a nonnegative integer scalar"
159                )));
160            }
161            parse_prefix_length_from_float(tensor.data[0])
162        }
163        Value::LogicalArray(array) => {
164            if array.data.len() != 1 {
165                return Err(strncmp_flow(format!(
166                    "{FN_NAME}: prefix length must be a nonnegative integer scalar"
167                )));
168            }
169            Ok(if array.data[0] != 0 { 1 } else { 0 })
170        }
171        other => Err(strncmp_flow(format!(
172            "{FN_NAME}: prefix length must be a nonnegative integer scalar, received {other:?}"
173        ))),
174    }
175}
176
177fn parse_prefix_length_from_float(value: f64) -> BuiltinResult<usize> {
178    if !value.is_finite() {
179        return Err(strncmp_flow(format!(
180            "{FN_NAME}: prefix length must be a finite nonnegative integer"
181        )));
182    }
183    if value < 0.0 {
184        return Err(strncmp_flow(format!(
185            "{FN_NAME}: prefix length must be a nonnegative integer"
186        )));
187    }
188    let rounded = value.round();
189    if (rounded - value).abs() > f64::EPSILON {
190        return Err(strncmp_flow(format!(
191            "{FN_NAME}: prefix length must be a nonnegative integer"
192        )));
193    }
194    if rounded > (usize::MAX as f64) {
195        return Err(strncmp_flow(format!(
196            "{FN_NAME}: prefix length exceeds the maximum supported size"
197        )));
198    }
199    Ok(rounded as usize)
200}
201
#[cfg(test)]
pub(crate) mod tests {
    //! Behavioral tests for `strncmp`: prefix semantics, prefix-length parsing, broadcasting over
    //! char/cell/string arrays, missing strings, error reporting and GPU-resident prefix lengths.
    use super::*;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::AccelProvider;
    use runmat_builtins::{
        CellArray, CharArray, IntValue, LogicalArray, ResolveContext, StringArray, Tensor, Type,
    };

    /// Synchronous wrapper around the async builtin for use in tests.
    fn strncmp_builtin(a: Value, b: Value, n: Value) -> BuiltinResult<Value> {
        futures::executor::block_on(super::strncmp_builtin(a, b, n))
    }

    fn error_message(err: crate::RuntimeError) -> String {
        err.message().to_string()
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_exact_prefix_true() {
        let result = strncmp_builtin(
            Value::String("RunMat".into()),
            Value::String("Runway".into()),
            Value::Int(IntValue::I32(3)),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_mismatch_within_prefix_false() {
        let result = strncmp_builtin(
            Value::String("RunMat".into()),
            Value::String("Runway".into()),
            Value::Int(IntValue::I32(4)),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_longer_string_after_prefix_false() {
        let result = strncmp_builtin(
            Value::String("cat".into()),
            Value::String("cater".into()),
            Value::Int(IntValue::I32(4)),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_prefix_longer_than_both_inputs_compares_whole_text() {
        let result = strncmp_builtin(
            Value::String("abc".into()),
            Value::String("abc".into()),
            Value::Int(IntValue::I32(10)),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_zero_length_always_true() {
        let result = strncmp_builtin(
            Value::String("alpha".into()),
            Value::String("omega".into()),
            Value::Num(0.0),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_prefix_length_bool_true_compares_first_character() {
        let result = strncmp_builtin(
            Value::String("alpha".into()),
            Value::String("array".into()),
            Value::Bool(true),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_prefix_length_bool_false_treated_as_zero() {
        let result = strncmp_builtin(
            Value::String("alpha".into()),
            Value::String("omega".into()),
            Value::Bool(false),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_prefix_length_logical_array_scalar() {
        let logical = LogicalArray::new(vec![1], vec![1]).unwrap();
        let result = strncmp_builtin(
            Value::String("beta".into()),
            Value::String("theta".into()),
            Value::LogicalArray(logical),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_prefix_length_tensor_scalar_double() {
        let limit = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
        let result = strncmp_builtin(
            Value::String("gamma".into()),
            Value::String("gamut".into()),
            Value::Tensor(limit),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_char_array_rows() {
        let chars = CharArray::new(
            vec![
                'c', 'a', 't', ' ', ' ', 'c', 'a', 'm', 'e', 'l', 'c', 'o', 'w', ' ', ' ',
            ],
            3,
            5,
        )
        .unwrap();
        let result = strncmp_builtin(
            Value::CharArray(chars),
            Value::String("ca".into()),
            Value::Int(IntValue::I32(2)),
        )
        .expect("strncmp");
        let expected = LogicalArray::new(vec![1, 1, 0], vec![3, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_cell_arrays_broadcast() {
        let left = CellArray::new(
            vec![
                Value::from("red"),
                Value::from("green"),
                Value::from("blue"),
            ],
            1,
            3,
        )
        .unwrap();
        let right = CellArray::new(
            vec![
                Value::from("rose"),
                Value::from("gray"),
                Value::from("black"),
            ],
            1,
            3,
        )
        .unwrap();
        let result = strncmp_builtin(
            Value::Cell(left),
            Value::Cell(right),
            Value::Int(IntValue::I32(2)),
        )
        .expect("strncmp");
        let expected = LogicalArray::new(vec![0, 1, 1], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_string_array_broadcast_scalar() {
        let strings = StringArray::new(
            vec!["north".into(), "south".into(), "east".into()],
            vec![1, 3],
        )
        .unwrap();
        let result = strncmp_builtin(
            Value::StringArray(strings),
            Value::String("no".into()),
            Value::Int(IntValue::I32(2)),
        )
        .expect("strncmp");
        let expected = LogicalArray::new(vec![1, 0, 0], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_missing_string_false_when_prefix_positive() {
        let strings =
            StringArray::new(vec!["<missing>".into(), "value".into()], vec![1, 2]).unwrap();
        let result = strncmp_builtin(
            Value::StringArray(strings),
            Value::String("val".into()),
            Value::Int(IntValue::I32(3)),
        )
        .expect("strncmp");
        let expected = LogicalArray::new(vec![0, 1], vec![1, 2]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_missing_zero_length_true() {
        let strings = StringArray::new(vec!["<missing>".into()], vec![1, 1]).unwrap();
        let result = strncmp_builtin(
            Value::StringArray(strings),
            Value::String("anything".into()),
            Value::Int(IntValue::I32(0)),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_size_mismatch_error() {
        let left = StringArray::new(vec!["a".into(), "b".into()], vec![2, 1]).unwrap();
        let right = StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![3, 1]).unwrap();
        let err = error_message(
            strncmp_builtin(
                Value::StringArray(left),
                Value::StringArray(right),
                Value::Int(IntValue::I32(1)),
            )
            .expect_err("size mismatch"),
        );
        assert!(err.contains("size mismatch"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_invalid_length_type_errors() {
        let err = error_message(
            strncmp_builtin(
                Value::String("abc".into()),
                Value::String("abc".into()),
                Value::String("3".into()),
            )
            .expect_err("invalid prefix length"),
        );
        assert!(err.contains("prefix length"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_negative_length_errors() {
        let err = error_message(
            strncmp_builtin(
                Value::String("abc".into()),
                Value::String("abc".into()),
                Value::Num(-1.0),
            )
            .expect_err("negative length"),
        );
        assert!(err.to_ascii_lowercase().contains("nonnegative"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_fractional_length_errors() {
        let err = error_message(
            strncmp_builtin(
                Value::String("abc".into()),
                Value::String("abc".into()),
                Value::Num(2.5),
            )
            .expect_err("fractional length"),
        );
        assert!(err.contains("nonnegative integer"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_non_finite_length_errors() {
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let err = error_message(
                strncmp_builtin(
                    Value::String("abc".into()),
                    Value::String("abc".into()),
                    Value::Num(bad),
                )
                .expect_err("non-finite length"),
            );
            assert!(err.contains("finite"));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_non_scalar_tensor_length_errors() {
        let limit = Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let err = error_message(
            strncmp_builtin(
                Value::String("abc".into()),
                Value::String("abc".into()),
                Value::Tensor(limit),
            )
            .expect_err("non-scalar length"),
        );
        assert!(err.contains("scalar"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn strncmp_prefix_length_from_gpu_tensor() {
        use runmat_accelerate::backend::wgpu::provider::{
            register_wgpu_provider, WgpuProviderOptions,
        };
        use runmat_accelerate_api::HostTensorView;

        // Skip silently when no wgpu adapter is available on the test machine.
        let provider = match register_wgpu_provider(WgpuProviderOptions::default()) {
            Ok(provider) => provider,
            Err(_) => return,
        };
        let tensor = Tensor::new(vec![3.0], vec![1, 1]).unwrap();
        let view = HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = provider.upload(&view).expect("upload prefix length to GPU");
        let result = strncmp_builtin(
            Value::String("delta".into()),
            Value::String("deluge".into()),
            Value::GpuTensor(handle.clone()),
        )
        .expect("strncmp");
        assert_eq!(result, Value::Bool(true));
        let _ = provider.free(&handle);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn strncmp_type_is_logical_match() {
        assert_eq!(
            logical_text_match_type(
                &[Type::String, Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Bool
        );
    }
}