use runmat_builtins::Value;
use runmat_macros::runtime_builtin;
use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
ReductionNaN, ResidencyPolicy, ShapeRequirements,
};
use crate::builtins::common::tensor;
use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult};
use crate::builtins::common::broadcast::{broadcast_index, broadcast_shapes, compute_strides};
use super::text_utils::{logical_result, parse_ignore_case, TextCollection, TextElement};
use crate::builtins::strings::type_resolvers::logical_text_match_type;
// GPU metadata for `endsWith`. No device implementation is described: no provider
// hooks or supported precisions are declared, residency is `GatherImmediately`, and
// `endswith_builtin` gathers every argument to the host before matching (see `notes`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::strings::search::endswith")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
name: "endsWith",
op_kind: GpuOpKind::Custom("string-search"),
// Empty: no GPU precisions are advertised for this text operation.
supported_precisions: &[],
// Subject/pattern pairing follows MATLAB broadcasting (`broadcast_shapes` in `evaluate_endswith`).
broadcast: BroadcastSemantics::Matlab,
// Empty: no provider-side kernels are registered.
provider_hooks: &[],
constant_strategy: ConstantStrategy::InlineLiteral,
// Device-resident inputs are moved to the host before evaluation.
residency: ResidencyPolicy::GatherImmediately,
// NOTE(review): the NaN fields look like required placeholders for a text op with no
// numeric reduction; presumably inert here — confirm.
nan_mode: ReductionNaN::Include,
two_pass_threshold: None,
workgroup_size: None,
accepts_nan_mode: false,
notes: "Executes entirely on the host; inputs are gathered from the GPU before evaluating suffix checks.",
};
// Fusion metadata for `endsWith`. No elementwise or reduction template is supplied
// (both `None` below); per `notes`, the builtin is not eligible for fusion and
// produces host-side logical results.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::strings::search::endswith")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
name: "endsWith",
// Subject and pattern shapes are not constrained here; compatibility is checked at
// evaluation time via broadcasting.
shape: ShapeRequirements::Any,
constant_strategy: ConstantStrategy::InlineLiteral,
elementwise: None,
reduction: None,
// Output is logical, so no NaN values are produced.
emits_nan: false,
notes: "Text operation; not eligible for fusion and materialises host logical results.",
};
// Builtin name reported by option parsing and by shape-mismatch errors.
const BUILTIN_NAME: &str = "endsWith";

// Entry point for `endsWith(text, pattern, ...)`.
//
// The subject, the pattern and every trailing option value are gathered to the
// host first. Trailing arguments are parsed for the IgnoreCase setting before the
// text arguments are classified, so an invalid option is reported ahead of an
// invalid subject or pattern. The per-element suffix test is delegated to
// `evaluate_endswith`.
#[runtime_builtin(
    name = "endsWith",
    category = "strings/search",
    summary = "Return logical values indicating whether text inputs end with specific patterns.",
    keywords = "endswith,suffix,text,ignorecase,search",
    accel = "sink",
    type_resolver(logical_text_match_type),
    builtin_path = "crate::builtins::strings::search::endswith"
)]
async fn endswith_builtin(
    text: Value,
    pattern: Value,
    rest: Vec<Value>,
) -> crate::BuiltinResult<Value> {
    let host_text = gather_if_needed_async(&text).await?;
    let host_pattern = gather_if_needed_async(&pattern).await?;

    let mut host_options = Vec::with_capacity(rest.len());
    for arg in &rest {
        let gathered = gather_if_needed_async(arg).await?;
        host_options.push(gathered);
    }

    // Option validation deliberately precedes text classification (error precedence).
    let ignore_case = parse_ignore_case(BUILTIN_NAME, &host_options)?;
    let subjects = TextCollection::from_subject(BUILTIN_NAME, host_text)?;
    let suffixes = TextCollection::from_pattern(BUILTIN_NAME, host_pattern)?;
    evaluate_endswith(&subjects, &suffixes, ignore_case)
}
// Host-side suffix test behind `endsWith`.
//
// `subject` and `patterns` are paired element-by-element after their shapes are
// broadcast together. Incompatible shapes become a runtime error tagged with
// `BUILTIN_NAME`. For each pair:
//   * a missing element on either side yields false (checked before anything else);
//   * an empty pattern yields true;
//   * otherwise the subject must end with the pattern, comparing the
//     `lowercased()` forms when `ignore_case` is set.
// The flags are packaged by `logical_result` using the broadcast output shape.
//
// NOTE(review): MATLAB's documented `endsWith` returns true when a string ends with
// *any* element of a multi-element pattern array. This implementation instead pairs
// patterns with subjects elementwise via broadcasting; confirm this is intended.
fn evaluate_endswith(
    subject: &TextCollection,
    patterns: &TextCollection,
    ignore_case: bool,
) -> BuiltinResult<Value> {
    let out_shape = broadcast_shapes(BUILTIN_NAME, &subject.shape, &patterns.shape)
        .map_err(|err| build_runtime_error(err).with_builtin(BUILTIN_NAME).build())?;
    let element_total = tensor::element_count(&out_shape);
    if element_total == 0 {
        // Nothing to compare, but the (empty) broadcast shape is still reported.
        return logical_result(BUILTIN_NAME, Vec::new(), out_shape);
    }

    let subject_strides = compute_strides(&subject.shape);
    let pattern_strides = compute_strides(&patterns.shape);

    // Lower-cased copies are built once per call, and only for case-insensitive matching.
    let folded_subjects = ignore_case.then(|| subject.lowercased());
    let folded_patterns = ignore_case.then(|| patterns.lowercased());

    let flags: Vec<_> = (0..element_total)
        .map(|linear| {
            let subj_idx =
                broadcast_index(linear, &out_shape, &subject.shape, &subject_strides);
            let pat_idx =
                broadcast_index(linear, &out_shape, &patterns.shape, &pattern_strides);
            let matched = match (&subject.elements[subj_idx], &patterns.elements[pat_idx]) {
                (TextElement::Missing, _) | (_, TextElement::Missing) => false,
                (TextElement::Text(_), TextElement::Text(suffix)) if suffix.is_empty() => true,
                (TextElement::Text(text), TextElement::Text(suffix)) => {
                    if !ignore_case {
                        text.ends_with(suffix.as_str())
                    } else {
                        // Assumes `lowercased()` yields `Some` for every `Text` entry;
                        // the `expect`s turn a violation of that into a panic.
                        let folded_text = folded_subjects
                            .as_ref()
                            .and_then(|all| all[subj_idx].as_deref())
                            .expect("lowercase subject available");
                        let folded_suffix = folded_patterns
                            .as_ref()
                            .and_then(|all| all[pat_idx].as_deref())
                            .expect("lowercase pattern available");
                        folded_text.ends_with(folded_suffix)
                    }
                }
            };
            if matched {
                1
            } else {
                0
            }
        })
        .collect();

    logical_result(BUILTIN_NAME, flags, out_shape)
}
#[cfg(test)]
pub(crate) mod tests {
    // Unit tests for `endsWith`. They cover:
    //   * scalar and array subjects;
    //   * broadcasting of patterns;
    //   * the IgnoreCase option in each accepted value form, including device-resident flags;
    //   * missing strings;
    //   * argument validation errors.
    use super::*;
    use crate::builtins::common::test_support;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{
        CellArray, CharArray, IntValue, LogicalArray, ResolveContext, StringArray, Tensor, Type,
    };

    // Drives the async builtin to completion on the current thread.
    fn run_endswith(text: Value, pattern: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(endswith_builtin(text, pattern, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_string_scalar_true() {
        let result = run_endswith(
            Value::String("RunMat".into()),
            Value::String("Mat".into()),
            Vec::new(),
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_string_scalar_false() {
        let result = run_endswith(
            Value::String("RunMat".into()),
            Value::String("Run".into()),
            Vec::new(),
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_option() {
        let result = run_endswith(
            Value::String("RunMat".into()),
            Value::String("mat".into()),
            vec![Value::String("IgnoreCase".into()), Value::Bool(true)],
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_string_array_scalar_pattern() {
        let array = StringArray::new(
            vec!["alpha".into(), "beta".into(), "gamma".into()],
            vec![3, 1],
        )
        .unwrap();
        let result = run_endswith(
            Value::StringArray(array),
            Value::String("a".into()),
            Vec::new(),
        )
        .expect("endsWith");
        let expected = LogicalArray::new(vec![1, 1, 1], vec![3, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    // Case-sensitive vs case-insensitive over an array subject. The insensitive run
    // exercises the per-element lowercase lookup via the broadcast subject index.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_string_array_subjects() {
        let make_subjects = || {
            StringArray::new(
                vec!["ALPHA".into(), "Beta".into(), "gamma".into()],
                vec![3, 1],
            )
            .unwrap()
        };
        let sensitive = run_endswith(
            Value::StringArray(make_subjects()),
            Value::String("A".into()),
            Vec::new(),
        )
        .expect("endsWith");
        let expected = LogicalArray::new(vec![1, 0, 0], vec![3, 1]).unwrap();
        assert_eq!(sensitive, Value::LogicalArray(expected));

        let insensitive = run_endswith(
            Value::StringArray(make_subjects()),
            Value::String("A".into()),
            vec![Value::String("IgnoreCase".into()), Value::Bool(true)],
        )
        .expect("endsWith");
        let expected = LogicalArray::new(vec![1, 1, 1], vec![3, 1]).unwrap();
        assert_eq!(insensitive, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_elementwise_patterns() {
        let subjects = StringArray::new(
            vec!["hydrogen".into(), "helium".into(), "lithium".into()],
            vec![3, 1],
        )
        .unwrap();
        let patterns =
            StringArray::new(vec!["gen".into(), "ium".into(), "ium".into()], vec![3, 1]).unwrap();
        let result = run_endswith(
            Value::StringArray(subjects),
            Value::StringArray(patterns),
            Vec::new(),
        )
        .expect("endsWith");
        let expected = LogicalArray::new(vec![1, 1, 1], vec![3, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    // A 3x1 char array is three single-character patterns broadcast against one subject.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_broadcast_pattern_column_vector() {
        let patterns = CharArray::new(vec!['n', 'x', 'r'], 3, 1).unwrap();
        let result = run_endswith(
            Value::String("saturn".into()),
            Value::CharArray(patterns),
            Vec::new(),
        )
        .expect("endsWith char");
        let expected = LogicalArray::new(vec![1, 0, 0], vec![3, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    // Case-insensitive matching with broadcast patterns exercises the pattern-side
    // lowercase lookup via the broadcast pattern index.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_broadcast_patterns() {
        let patterns =
            StringArray::new(vec!["MAT".into(), "run".into(), "T".into()], vec![3, 1]).unwrap();
        let result = run_endswith(
            Value::String("RunMat".into()),
            Value::StringArray(patterns),
            vec![Value::String("IgnoreCase".into()), Value::Bool(true)],
        )
        .expect("endsWith");
        let expected = LogicalArray::new(vec![1, 0, 1], vec![3, 1]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    // The cell array is the *subject* here; each element is tested against "s".
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_cell_array_subjects() {
        let cell = CellArray::new(
            vec![
                Value::from("Mercury"),
                Value::from("Venus"),
                Value::from("Mars"),
            ],
            1,
            3,
        )
        .unwrap();
        let result = run_endswith(Value::Cell(cell), Value::String("s".into()), Vec::new())
            .expect("endsWith");
        let expected = LogicalArray::new(vec![0, 1, 1], vec![1, 3]).unwrap();
        assert_eq!(result, Value::LogicalArray(expected));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_missing_strings_false() {
        let array = StringArray::new(vec!["<missing>".into()], vec![1, 1]).unwrap();
        let result = run_endswith(
            Value::StringArray(array),
            Value::String("a".into()),
            Vec::new(),
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(false));
    }

    // A missing subject is false even against the empty pattern, which otherwise
    // always matches.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_missing_subject_empty_pattern_false() {
        let array = StringArray::new(vec!["<missing>".into()], vec![1, 1]).unwrap();
        let result = run_endswith(
            Value::StringArray(array),
            Value::String("".into()),
            Vec::new(),
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(false));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_empty_pattern_true() {
        let result = run_endswith(
            Value::String("foo".into()),
            Value::String("".into()),
            Vec::new(),
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_invalid_option_name() {
        let err = run_endswith(
            Value::String("foo".into()),
            Value::String("o".into()),
            vec![Value::String("IgnoreCases".into()), Value::Bool(true)],
        )
        .unwrap_err();
        assert!(err.to_string().contains("unknown option"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_string_flag() {
        let result = run_endswith(
            Value::String("RunMat".into()),
            Value::String("mat".into()),
            vec![
                Value::String("IgnoreCase".into()),
                Value::String("on".into()),
            ],
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_numeric_flag() {
        let result = run_endswith(
            Value::String("RunMat".into()),
            Value::String("mat".into()),
            vec![
                Value::String("IgnoreCase".into()),
                Value::Int(IntValue::I32(0)),
            ],
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(false));
    }

    // A lone trailing value (no "IgnoreCase" name) is accepted as the flag.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_positional_value() {
        let result = run_endswith(
            Value::String("RunMat".into()),
            Value::String("mat".into()),
            vec![Value::Bool(true)],
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_logical_array_value() {
        let logical = LogicalArray::new(vec![1], vec![1, 1]).unwrap();
        let result = run_endswith(
            Value::String("RunMat".into()),
            Value::String("mat".into()),
            vec![
                Value::String("IgnoreCase".into()),
                Value::LogicalArray(logical),
            ],
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(true));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_tensor_value() {
        let tensor = Tensor::new(vec![0.0], vec![1, 1]).unwrap();
        let result = run_endswith(
            Value::String("RunMat".into()),
            Value::String("mat".into()),
            vec![Value::String("IgnoreCase".into()), Value::Tensor(tensor)],
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(false));
    }

    // A device-resident flag must be gathered before option parsing.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_gpu_tensor_flag() {
        test_support::with_test_provider(|provider| {
            let data = [1.0];
            let shape = [1usize, 1usize];
            let handle = provider
                .upload(&HostTensorView {
                    data: &data,
                    shape: &shape,
                })
                .expect("upload");
            let result = run_endswith(
                Value::String("RunMat".into()),
                Value::String("mat".into()),
                vec![
                    Value::String("IgnoreCase".into()),
                    Value::GpuTensor(handle.clone()),
                ],
            )
            .expect("endsWith");
            assert_eq!(result, Value::Bool(true));
            provider.free(&handle).expect("free gpu flag");
        });
    }

    // Same as above against a real wgpu provider; skipped when none can be registered.
    #[cfg(feature = "wgpu")]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_gpu_tensor_flag_wgpu() {
        use runmat_accelerate::backend::wgpu::provider::{
            register_wgpu_provider, WgpuProviderOptions,
        };
        if register_wgpu_provider(WgpuProviderOptions::default()).is_err() {
            return;
        }
        let Some(provider) = runmat_accelerate_api::provider() else {
            return;
        };
        let data = [1.0];
        let shape = [1usize, 1usize];
        let handle = provider
            .upload(&HostTensorView {
                data: &data,
                shape: &shape,
            })
            .expect("upload");
        let result = run_endswith(
            Value::String("RunMat".into()),
            Value::String("mat".into()),
            vec![
                Value::String("IgnoreCase".into()),
                Value::GpuTensor(handle.clone()),
            ],
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(true));
        let _ = provider.free(&handle);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_invalid_value() {
        let err = run_endswith(
            Value::String("RunMat".into()),
            Value::String("mat".into()),
            vec![
                Value::String("IgnoreCase".into()),
                Value::String("maybe".into()),
            ],
        )
        .unwrap_err();
        assert!(err.to_string().contains("invalid value"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_logical_array_invalid_size() {
        let logical = LogicalArray::new(vec![1, 0], vec![2, 1]).unwrap();
        let err = run_endswith(
            Value::String("RunMat".into()),
            Value::String("mat".into()),
            vec![
                Value::String("IgnoreCase".into()),
                Value::LogicalArray(logical),
            ],
        )
        .unwrap_err();
        assert!(err.to_string().contains("scalar logicals"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_numeric_nan_invalid() {
        let err = run_endswith(
            Value::String("RunMat".into()),
            Value::String("mat".into()),
            vec![Value::Num(f64::NAN)],
        )
        .unwrap_err();
        assert!(err.to_string().contains("finite scalar"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_ignore_case_missing_value() {
        let err = run_endswith(
            Value::String("RunMat".into()),
            Value::String("mat".into()),
            vec![Value::String("IgnoreCase".into())],
        )
        .unwrap_err();
        assert!(err
            .to_string()
            .contains("expected a value after 'IgnoreCase'"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_mismatched_shapes_error() {
        let text = StringArray::new(vec!["a".into(), "b".into()], vec![2, 1]).unwrap();
        let pattern =
            StringArray::new(vec!["a".into(), "b".into(), "c".into()], vec![3, 1]).unwrap();
        let err = run_endswith(
            Value::StringArray(text),
            Value::StringArray(pattern),
            Vec::new(),
        )
        .unwrap_err();
        assert!(err.to_string().contains("size mismatch"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_invalid_subject_type() {
        let err = run_endswith(Value::Num(1.0), Value::String("a".into()), Vec::new()).unwrap_err();
        assert!(err.to_string().contains("first argument must be text"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_invalid_pattern_type() {
        let err =
            run_endswith(Value::String("foo".into()), Value::Num(1.0), Vec::new()).unwrap_err();
        assert!(
            err.to_string().contains("pattern must be text"),
            "expected pattern type error, got: {err}"
        );
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_cell_invalid_element_error() {
        let cell = CellArray::new(vec![Value::Num(1.0)], 1, 1).unwrap();
        let err =
            run_endswith(Value::Cell(cell), Value::String("a".into()), Vec::new()).unwrap_err();
        assert!(err.to_string().contains("cell array elements"));
    }

    // Empty inputs still produce a logical array that carries the broadcast shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_zero_sized_inputs() {
        let subjects = StringArray::new(Vec::<String>::new(), vec![0, 1]).unwrap();
        let result = run_endswith(
            Value::StringArray(subjects),
            Value::String("a".into()),
            Vec::new(),
        )
        .expect("endsWith");
        match result {
            Value::LogicalArray(array) => {
                assert_eq!(array.shape, vec![0, 1]);
                assert!(array.data.is_empty());
            }
            other => panic!("expected logical array, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn endswith_missing_pattern_false() {
        let result = run_endswith(
            Value::String("alpha".into()),
            Value::String("<missing>".into()),
            Vec::new(),
        )
        .expect("endsWith");
        assert_eq!(result, Value::Bool(false));
    }

    #[test]
    fn endswith_type_is_logical_match() {
        assert_eq!(
            logical_text_match_type(
                &[Type::String, Type::String],
                &ResolveContext::new(Vec::new()),
            ),
            Type::Bool
        );
    }
}